mirror of https://github.com/hpcaitech/ColossalAI
[tensor] derive compute pattern from dist spec (#971)
* derive compute pattern from dist spec
* polish code

parent 46bc95708f
commit c2fdc6a011
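Summary of the change: the per-op compute patterns (TP1DRow, TP1DCol, TP1DRow_Linear, and so on) are collapsed into a single ComputePattern.TP1D, alongside ZeRO and DP. Whether an op runs its row- or column-parallel kernel is now derived from the tensor's distribution spec: the new TensorSpec.is_1D_row() / is_1D_col() predicates inspect which dimension dist_spec.shard() splits, and has_compute_pattern() replaces the old num_action / compute_patterns checks in the op dispatch. Below is a sketch (not part of the diff) of how a spec is declared after this change, mirroring the updated tests; the imports assume the ColossalAI version this commit targets and an already-launched 1D parallel context.

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_spec

    # Shard dim 0 of a parameter across the 1D tensor-parallel group.
    # The op layer treats this as the "1D row" layout via spec.is_1D_row();
    # sharding dim -1 instead would be picked up by spec.is_1D_col().
    spec = TensorSpec(
        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                        [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D,
                        parallel_mode=ParallelMode.PARALLEL_1D)])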
@@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec

 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                      alpha: Union[int, float]) -> ColoTensor:
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

@@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
                      alpha: Union[int, float]) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
     mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)

@@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg):
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_spec():    # No Model Parallel Applied
-        assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
+        assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.init_from_torch_tensor(
-            torch.addbmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
-    elif mat2.spec.num_action == 1:    # Single Model Parallel Applied
+            torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
+    elif mat2.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
         mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
-        compute_patterns = mat2.spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+        if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
             ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
             ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
         else:
             raise NotImplementedError
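Note (not part of the diff): the 1Drow path above computes only a partial product on each rank, with mat1 sharded along its columns (S[1]) and mat2 along its rows (S[0]); an all-reduce plus the beta/alpha terms then reproduce torch.addmm. A minimal single-process sketch of that arithmetic in plain torch, with hypothetical shapes and torch.chunk plus a sum standing in for the sharding and the all-reduce:

    import torch

    world_size = 2
    input_t = torch.randn(4, 6)    # the "input" term of addmm
    mat1 = torch.randn(4, 8)       # logically sharded along dim 1 (S[1])
    mat2 = torch.randn(8, 6)       # logically sharded along dim 0 (S[0])
    beta, alpha = 1.0, 1.0

    # Shard mat1 column-wise and mat2 row-wise, as the 1Drow comments describe.
    mat1_shards = torch.chunk(mat1, world_size, dim=1)
    mat2_shards = torch.chunk(mat2, world_size, dim=0)

    # Each "rank" produces a partial product; summing them plays the role of All-Reduce.
    reduced = sum(m1 @ m2 for m1, m2 in zip(mat1_shards, mat2_shards))
    out = beta * input_t + alpha * reduced

    # Matches the unsharded torch.addmm result.
    assert torch.allclose(out, torch.addmm(input_t, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)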
@@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall
 def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))

     output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)

@@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))

     tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)

@@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg):
     weight = ColoTensor.init_from_torch_tensor(weight)

     # Handle differen parallel actions.

     if not weight.has_spec():    # No Model Parallel Applied
+        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
         input_tensor = input_tensor.torch_tensor()
         weight = weight.torch_tensor()
         output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
         return ColoTensor.init_from_torch_tensor(output)
-    elif weight.spec.num_action == 1:    # Single Model Parallel Applied
-        compute_patterns = weight.spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.spec.is_1D_row():
             return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif weight.spec.is_1D_col():
             return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
         else:
             raise NotImplementedError
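Note (not part of the diff): for the row-sharded lookup table, the comments in colo_embedding_1Drow describe the usual vocab-parallel trick: each rank keeps a contiguous slice of the embedding rows, looks up only the indices that fall in its slice, zeroes the rest, and an all-reduce restores the full result. A single-process sketch of that masking logic in plain torch, with hypothetical sizes and summation standing in for the all-reduce:

    import torch
    import torch.nn.functional as F

    world_size, num_embeddings, embedding_dim = 2, 8, 4
    weight = torch.randn(num_embeddings, embedding_dim)
    ids = torch.tensor([[0, 5, 3], [7, 2, 6]])

    per_partition = num_embeddings // world_size
    partials = []
    for rank in range(world_size):
        lo, hi = rank * per_partition, (rank + 1) * per_partition
        shard = weight[lo:hi]               # this rank's rows of the lookup table
        mask = (ids < lo) | (ids >= hi)     # indices owned by other ranks
        local_ids = ids.clone() - lo
        local_ids[mask] = 0                 # any in-range dummy index
        out = F.embedding(local_ids, shard)
        out[mask] = 0.0                     # zero rows not owned by this rank
        partials.append(out)

    # Summing the partial outputs plays the role of the all-reduce.
    out = torch.stack(partials).sum(dim=0)
    assert torch.allclose(out, F.embedding(ids, weight))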
@@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]

@@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
-    output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
+    if bias is not None:
+        bias = bias.torch_tensor()
+    output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)

     output = ColoTensor.init_from_torch_tensor(
         output_parallel,

@@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg):
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
-        assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
+        assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
+        assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
         input_tensor = input_tensor.torch_tensor()
         weight = weight.torch_tensor()
-        bias = bias.torch_tensor()
+        if bias is not None:
+            bias = bias.torch_tensor()
         ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
-    elif weight.spec.num_action == 1:    # Single Model Parallel Applied
-        compute_patterns = weight.spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
             ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
             ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
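Note (not part of the diff): the two linear paths follow the layouts spelled out in the comments, row-parallel producing partial outputs that are all-reduced before the bias is added, column-parallel producing a slice of the output features (with a matching bias slice) that is all-gathered. A single-process sketch of both layouts, writing the weight in (in_features, out_features) math orientation to match the S[0]/S[1] notation above; chunk, sum and cat stand in for the sharding and the collectives, and the shapes are hypothetical:

    import torch

    world_size = 2
    x = torch.randn(5, 8)     # input, features along the last dim
    w = torch.randn(8, 6)     # weight in math orientation (in_features, out_features)
    b = torch.randn(6)

    full = x @ w + b          # reference result

    # Row-parallel: Input:S[1] x Weight:S[0] = Output:P, then All-Reduce(Output) + bias.
    x_shards = torch.chunk(x, world_size, dim=1)
    w_row_shards = torch.chunk(w, world_size, dim=0)
    partial = sum(xs @ ws for xs, ws in zip(x_shards, w_row_shards))   # "all-reduce"
    assert torch.allclose(partial + b, full, atol=1e-5)

    # Column-parallel: Input:B x Weight:S[1] + Bias:S[1] = Output:S[1], then All-Gather(Output).
    w_col_shards = torch.chunk(w, world_size, dim=1)
    b_shards = torch.chunk(b, world_size, dim=0)
    outs = [x @ ws + bs for ws, bs in zip(w_col_shards, b_shards)]
    gathered = torch.cat(outs, dim=1)                                  # "all-gather"
    assert torch.allclose(gathered, full, atol=1e-5)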
@@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
 from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D


 @colo_op_impl(torch.nn.functional.cross_entropy)
 def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
     arg_num = len(args)

@@ -27,13 +28,13 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
     if isinstance(target, ColoTensor):
         target = target.torch_tensor()

     if input_tensor.spec.is_gathered():    # Input is gathered
-        return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
-            input_tensor.torch_tensor(), target, weight))
+        return ColoTensor.init_from_torch_tensor(
+            torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
     elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
-        if input_tensor.spec.is_1Dcol():
-            return ColoTensor.init_from_torch_tensor(
-                VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
+        if input_tensor.spec.is_1D_col():
+            return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
+                                                                                       target))
         else:
             raise NotImplementedError
     else:
@@ -1,6 +1,7 @@
 from enum import Enum
 from torch.distributed import ProcessGroup
 from typing import Optional, List
+from numpy import prod

 __all__ = ['replicate', 'shard']


@@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
+    assert prod(num_partitions) == process_group.size()
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
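Note (not part of the diff): dist_spec.shard() now also asserts that the partition counts cover the whole process group, i.e. prod(num_partitions) == process_group.size(). A hedged usage sketch; the group handle is assumed to come from an already-initialized ColossalAI context with a 4-rank 1D parallel group:

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import dist_spec

    group = gpc.get_group(ParallelMode.PARALLEL_1D)            # assumed size 4
    ok = dist_spec.shard(group, dims=[0], num_partitions=[4])  # 4 partitions == 4 ranks, passes

    # By contrast, this would now trip the new assertion on a 4-rank group:
    # dist_spec.shard(group, dims=[0], num_partitions=[2])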
@@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):
-    # TODO (ver217): remove TP1DRow_<ops>
-    TP1DRow = 0
-    TP1DCol = 9
-    TP1DRow_Linear = 1
-    TP1DCol_Linear = 2
-    TP1DRow_Embedding = 3
-    TP1DCol_Embedding = 4
-    TP1DRow_mm = 5
-    TP1DCol_mm = 6
-    ZeRO = 7
-    DP = 8
+    TP1D = 0
+    ZeRO = 1
+    DP = 2


 class ParallelAction(object):

@@ -45,14 +37,14 @@ class TensorSpec(object):
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
     # parallel_action_list = [
     #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
+    #     ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
     # ]
     # When the ColoTensor is initialized,
     # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1DRow_Linear.
+    # then splitting tensor according to ParallelAction of TP1D_Linear.
     # During Linear computation
     # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1DRow_Linear.
+    # We perform Linear Op according to compute pattern of TP1D_Linear.
     # After Linear Op, we split the tensors according to ZeRO.

     def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):

@@ -94,6 +86,13 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)

-    def is_1Dcol(self):
+    def is_1D_col(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
+
+    def is_1D_row(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD \
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+
+    def has_compute_pattern(self, compute_pattern: ComputePattern):
+        return self.get_action_by_compute_pattern(compute_pattern) is not None
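Note (not part of the diff): with is_1D_row(), is_1D_col() and has_compute_pattern() on TensorSpec, the op wrappers no longer branch on a per-op ComputePattern member; they ask the dist spec which dimension is sharded. The stand-in classes below are a simplified illustration of that rule, not the ColossalAI definitions (the real is_gathered() also accepts single-partition shards and size-1 groups). Because torch stores linear weights as (out_features, in_features), a weight sharded along dim -1 selects the 1Drow linear path and one sharded along dim 0 selects the 1Dcol path, exactly as in the colo_linear dispatch above.

    from dataclasses import dataclass
    from enum import Enum


    class Placement(Enum):
        REPLICATE = 0
        SHARD = 1


    @dataclass
    class MiniDistSpec:
        placement: Placement
        dims: tuple = ()


    class MiniSpec:
        def __init__(self, dist_spec: MiniDistSpec):
            self.dist_spec = dist_spec

        def is_gathered(self):
            return self.dist_spec.placement == Placement.REPLICATE

        def is_1D_row(self):
            return self.dist_spec.placement == Placement.SHARD and self.dist_spec.dims == (0,)

        def is_1D_col(self):
            return self.dist_spec.placement == Placement.SHARD and self.dist_spec.dims == (-1,)


    def pick_linear_path(weight_spec: MiniSpec) -> str:
        # Mirrors the new branch structure in colo_linear: the parallel mode is
        # read off the weight's dist spec, not a dedicated ComputePattern member.
        if weight_spec.is_gathered():
            return 'native'
        if weight_spec.is_1D_col():
            return '1Drow'
        if weight_spec.is_1D_row():
            return '1Dcol'
        raise NotImplementedError


    assert pick_linear_path(MiniSpec(MiniDistSpec(Placement.SHARD, (-1,)))) == '1Drow'
    assert pick_linear_path(MiniSpec(MiniDistSpec(Placement.SHARD, (0,)))) == '1Dcol'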
@@ -40,7 +40,7 @@ class Conv1D(nn.Module):
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
@@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
 def init_1d_row(weight):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
 def init_1d_col(weight):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
@@ -86,35 +86,43 @@ def set_seed(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True


 def init_1d_row_linear(weight):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_col_linear(weight, gather_out=True):
     spec = TensorSpec(
-        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
-            gather_out=gather_out)])
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1D,
+                           parallel_mode=ParallelMode.PARALLEL_1D,
+                           gather_out=gather_out)
+        ])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_row_embedding(weight):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_col_embedding(weight):
     spec = TensorSpec(
         dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)

@@ -217,11 +225,11 @@ def run_1d_hybrid_tp(model_name):
             assert torch.allclose(p1, p2)
         else:
             # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
             if p1.size(-1) < p2.size(-1):    # col
                 world_size = p2.size(-1) // p1.size(-1)
                 split_p2 = torch.chunk(p2, world_size, dim=-1)[0]

             elif p1.size(0) < p2.size(0):    # row
                 world_size = p2.size(0) // p1.size(0)
                 split_p2 = torch.chunk(p2, world_size, dim=0)[0]

@@ -376,7 +384,7 @@ def _run_pretrain_load():
         if isinstance(param, ColoParameter):
             c1 += 1
         else:
-            c2 +=1
+            c2 += 1
         dict_col[name] = param
     assert c_ref == c1
     assert c2 == 0

@@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port):
     for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 # @parameterize('world_size', [1, 4])