mirror of https://github.com/hpcaitech/ColossalAI

[tensor] refactor parallel action (#1007)

* refactor parallel action
* polish unit tests

branch: pull/1010/head
parent 9e3d602dba
commit a3b66f6def
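The refactor below collapses TensorSpec's list of prioritized ParallelAction objects into a single optional action, and drops the priority and parallel_mode fields from ParallelAction itself. A minimal before/after sketch of one call site, using only names that appear in the diff (the 1D row-shard spec from the tests):

    # old API: a list of actions, each carrying a priority and an explicit ParallelMode
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D,
                        parallel_mode=ParallelMode.PARALLEL_1D)])

    # new API: at most one action, identified by its ComputePattern alone
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))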
@@ -3,12 +3,12 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 from colossalai.tensor import distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor


 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

@@ -18,7 +18,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
     assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output

@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = mat2.spec.parallel_action
     mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
-    mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
+    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+                             ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     if parallel_action.gather_out:
         # All-Gather(Output)
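The row-parallel addmm above computes a partial product per rank and then all-reduces, as the op's own comments note. A single-process sketch of that data flow, simulating the ranks with plain lists (the sum over partials stands in for reduce_input over ParallelMode.PARALLEL_1D):

    import torch

    world_size = 2
    mat1 = torch.randn(4, 6)    # logical mat1, sharded along dim 1 (S[1])
    mat2 = torch.randn(6, 8)    # logical mat2, sharded along dim 0 (S[0])
    inp = torch.randn(4, 8)     # replicated input
    beta, alpha = 1.0, 1.0

    mat1_shards = mat1.chunk(world_size, dim=1)
    mat2_shards = mat2.chunk(world_size, dim=0)

    # each simulated rank holds a partial result (Output:P) of the full product
    partials = [m1 @ m2 for m1, m2 in zip(mat1_shards, mat2_shards)]
    output = beta * inp + alpha * sum(partials)    # sum == All-Reduce(Output)

    assert torch.allclose(output, torch.addmm(inp, mat1, mat2), atol=1e-5)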
@@ -1,10 +1,10 @@
 import torch
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor


@@ -17,7 +17,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

     output_parallel = F.embedding(input_tensor,

@@ -29,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   sparse=sparse)
     output_spec = TensorSpec(
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+        ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     return output

@@ -45,10 +44,9 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

-    tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition

@@ -72,7 +70,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
     return output
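colo_embedding_1Drow above splits the lookup table by vocabulary range, masks out-of-range ids, and reduces the partial lookups. A single-process sketch of that scheme, with the final sum standing in for reduce_input:

    import torch
    import torch.nn.functional as F

    world_size = 2
    num_embeddings, dim = 8, 4
    weight = torch.randn(num_embeddings, dim)   # logical table, row-sharded
    ids = torch.tensor([0, 3, 5, 7])

    per_rank = num_embeddings // world_size
    partials = []
    for rank in range(world_size):              # simulate each rank in turn
        vocab_start = rank * per_rank
        vocab_end = vocab_start + per_rank
        mask = (ids < vocab_start) | (ids >= vocab_end)
        local_ids = ids.clone() - vocab_start
        local_ids[mask] = 0                     # clamp to a valid local index
        partial = F.embedding(local_ids, weight[vocab_start:vocab_end])
        partial[mask, :] = 0.                   # zero rows this rank does not own
        partials.append(partial)

    output = sum(partials)                      # sum == Reduce across ranks
    assert torch.allclose(output, F.embedding(ids, weight))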
@@ -1,16 +1,14 @@
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
 from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]

@@ -20,7 +18,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Output:P
     partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
         assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'

@@ -34,15 +32,16 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = weight.spec.parallel_action
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
+    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)

     output_parallel = F.linear(input_parallel, weight, bias)
-    output = ColoTensor.from_torch_tensor(
-        output_parallel,
-        spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
+    output = ColoTensor.from_torch_tensor(output_parallel,
+                                          spec=TensorSpec(
+                                              distspec.shard(weight.spec.get_process_group(), [-1],
+                                                             [weight.spec.get_process_group_size()]),
+                                              ParallelAction(ComputePattern.TP1D)))
     if parallel_action.gather_out:
         # All-Gather(Output)
         output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
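colo_linear_1Dcol above is where the retained gather_out flag matters: each rank computes a slice of the output features, and gather_out decides whether to all-gather them back into a replicated tensor. A single-process sketch, with torch.cat standing in for the All-Gather:

    import torch
    import torch.nn.functional as F

    world_size = 2
    x = torch.randn(5, 6)          # input is replicated (B)
    weight = torch.randn(8, 6)     # (out_features, in_features), sharded on dim 0
    bias = torch.randn(8)          # bias sharded the same way

    w_shards = weight.chunk(world_size, dim=0)
    b_shards = bias.chunk(world_size, dim=0)

    # each simulated rank produces its slice of the output features
    shard_outputs = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]
    output = torch.cat(shard_outputs, dim=-1)    # the gather_out=True path

    assert torch.allclose(output, F.linear(x, weight, bias))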
@@ -28,7 +28,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                reduction=reduction,
                                label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
+    elif input_tensor.has_spec():    # Single Model Parallel Applied
         if input_tensor.spec.is_1D_col():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
@@ -33,3 +33,6 @@ class ColoParameter(ColoTensor):
         tensor = tensor.as_subclass(ColoParameter)
         tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
         return tensor
+
+    def __repr__(self):
+        return f'ColoParameter: {torch.Tensor.__repr__(self)}'
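The new __repr__ simply prefixes the tensor printout with the class name, so ColoParameter instances are distinguishable from plain tensors in logs. The same pattern, shown standalone on a hypothetical subclass (MyParam is not a ColossalAI name):

    import torch

    class MyParam(torch.Tensor):

        def __repr__(self):
            # prefix the stock tensor repr with the subclass name
            return f'MyParam: {torch.Tensor.__repr__(self)}'

    t = torch.zeros(2).as_subclass(MyParam)
    print(t)    # MyParam: tensor([0., 0.])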
@@ -45,7 +45,7 @@ class ColoTensor(torch.Tensor):
         self._spec = spec

     def has_spec(self) -> bool:
-        return self._spec.num_action > 0
+        return self._spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
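With a single optional action, has_spec() above already implies "exactly one action", which is why the `num_action == 1` guard could be dropped from colo_cross_entropy earlier in the diff. A tiny standalone model of the new invariant (FakeSpec is a hypothetical stand-in, not a ColossalAI class):

    from typing import Optional

    class FakeSpec:

        def __init__(self, parallel_action: Optional[str] = None):
            self.parallel_action = parallel_action

        def has_spec(self) -> bool:
            # mirrors ColoTensor.has_spec after this commit
            return self.parallel_action is not None

    assert not FakeSpec().has_spec()      # no action -> no spec
    assert FakeSpec('TP1D').has_spec()    # one action -> spec present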
@@ -1,26 +1,21 @@
 import torch.distributed as dist
 from enum import Enum
-from typing import List
-from colossalai.context.parallel_mode import ParallelMode
+from typing import List, Optional
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):
     TP1D = 0
-    ZeRO = 1
-    DP = 2
+    TP2D = 1
+    TP2P5D = 2
+    TP3D = 3


 class ParallelAction(object):

-    def __init__(self,
-                 priority=0,
-                 compute_pattern=ComputePattern.DP,
-                 parallel_mode=ParallelMode.DATA,
-                 gather_out=True) -> None:
-        self.priority = priority
+    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
+        assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
-        self.parallel_mode = parallel_mode
         self.gather_out = gather_out

@@ -48,32 +43,9 @@ class TensorSpec(object):
     # We perform Linear Op according to compute pattern of TP1D_Linear.
     # After Linear Op, we split the tensors according to ZeRO.

-    def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
-        self._parallel_action_list = parallel_action_list
+    def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
+        self.parallel_action = parallel_action
         self.dist_spec = dist_spec
-        self.sort()
-
-    @property
-    def parallel_action_list(self):
-        return self._parallel_action_list
-
-    @property
-    def num_action(self):
-        return len(self._parallel_action_list)
-
-    @property
-    def compute_patterns(self):
-        return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
-
-    def sort(self):
-        if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
-
-    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
-        for parallel_action in self._parallel_action_list:
-            if parallel_action.compute_pattern == compute_pattern:
-                return parallel_action
-        return None
-
     def get_process_group(self):
         return self.dist_spec.process_group

@@ -99,4 +71,4 @@ class TensorSpec(object):
                and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def has_compute_pattern(self, compute_pattern: ComputePattern):
-        return self.get_action_by_compute_pattern(compute_pattern) is not None
+        return self.parallel_action.compute_pattern == compute_pattern
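This hunk is the heart of the refactor: priority, sort(), num_action, and get_action_by_compute_pattern all disappear, and ComputePattern now enumerates tensor-parallel variants instead of ZeRO/DP. A self-contained mirror of the slimmed-down API, runnable without colossalai (the dist_spec is faked, and the has_compute_pattern here adds a None guard that the real method omits):

    from enum import Enum
    from typing import Optional

    class ComputePattern(Enum):
        TP1D = 0
        TP2D = 1
        TP2P5D = 2
        TP3D = 3

    class ParallelAction:

        def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
            assert isinstance(compute_pattern, ComputePattern)
            self.compute_pattern = compute_pattern
            self.gather_out = gather_out

    class TensorSpec:

        def __init__(self, dist_spec, parallel_action: Optional[ParallelAction] = None):
            self.parallel_action = parallel_action
            self.dist_spec = dist_spec

        def has_compute_pattern(self, compute_pattern: ComputePattern) -> bool:
            return self.parallel_action is not None \
                and self.parallel_action.compute_pattern == compute_pattern

    spec = TensorSpec(dist_spec=None, parallel_action=ParallelAction(ComputePattern.TP1D))
    assert spec.has_compute_pattern(ComputePattern.TP1D)
    assert not spec.has_compute_pattern(ComputePattern.TP2D)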
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -26,7 +26,7 @@ def init_1d_row(weight):
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def init_1d_row_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:

@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
 def init_1d_col_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
@@ -20,19 +20,15 @@ from _utils import set_seed
 def init_1d_row_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_col_linear(weight, gather_out=True):
     spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1D,
-                           parallel_mode=ParallelMode.PARALLEL_1D,
-                           gather_out=gather_out)
-        ])
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
 def init_1d_row_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
 def init_1d_col_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
         weight.set_spec(spec)
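Across all the test hunks, the only thing that varies is the dims argument of distspec.shard: [0] shards the leading dimension, [-1] the trailing one, and the partition count matches gpc.get_world_size(ParallelMode.PARALLEL_1D). A plain-torch illustration of what each choice selects (illustrative only: the real distspec.shard records placement metadata on a process group rather than chunking a local tensor):

    import torch

    weight = torch.arange(12.).reshape(4, 3)
    world_size = 2

    row_shards = weight.chunk(world_size, dim=0)     # distspec.shard(..., [0], [2])
    col_shards = weight.chunk(world_size, dim=-1)    # distspec.shard(..., [-1], [2])

    assert row_shards[0].shape == (2, 3)    # split across rows
    assert col_shards[0].shape == (4, 2)    # split across the last dim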