mirror of https://github.com/hpcaitech/ColossalAI
[tensor] derive compute pattern from dist spec (#971)
* derive compute pattern from dist spec * polish codepull/981/head
parent
46bc95708f
commit
c2fdc6a011
|
@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec
|
|||
|
||||
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
|
||||
alpha: Union[int, float]) -> ColoTensor:
|
||||
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
# mat1:S[1] x mat2:S[0] = Output:P
|
||||
# beta * input + alpha * All-Reduce(Output) = res
|
||||
|
||||
|
@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
|
||||
alpha: Union[int, float]) -> ColoTensor:
|
||||
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
|
||||
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
|
||||
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
|
||||
mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
|
||||
|
||||
|
@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg):
|
|||
# Add communication logic before and after linear call.
|
||||
ret_tensor = None
|
||||
if not mat2.has_spec(): # No Model Parallel Applied
|
||||
assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
|
||||
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
|
||||
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
|
||||
ret_tensor = ColoTensor.init_from_torch_tensor(
|
||||
torch.addbmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
|
||||
elif mat2.spec.num_action == 1: # Single Model Parallel Applied
|
||||
torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
|
||||
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
|
||||
mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
|
||||
compute_patterns = mat2.spec.compute_patterns
|
||||
if ComputePattern.TP1DRow in compute_patterns:
|
||||
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
|
||||
ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
|
||||
elif ComputePattern.TP1DCol in compute_patterns:
|
||||
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
|
||||
ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall
|
|||
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
|
||||
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
|
||||
|
||||
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
|
||||
|
@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
|
|||
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
|
||||
# Find index in this shard and mask those not here
|
||||
# Reduce all
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
|
||||
|
||||
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
|
@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg):
|
|||
weight = ColoTensor.init_from_torch_tensor(weight)
|
||||
|
||||
# Handle differen parallel actions.
|
||||
|
||||
if not weight.has_spec(): # No Model Parallel Applied
|
||||
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
weight = weight.torch_tensor()
|
||||
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
|
||||
return ColoTensor.init_from_torch_tensor(output)
|
||||
elif weight.spec.num_action == 1: # Single Model Parallel Applied
|
||||
compute_patterns = weight.spec.compute_patterns
|
||||
if ComputePattern.TP1DRow in compute_patterns:
|
||||
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.spec.is_1D_row():
|
||||
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
|
||||
elif ComputePattern.TP1DCol in compute_patterns:
|
||||
elif weight.spec.is_1D_col():
|
||||
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
|
|||
|
||||
|
||||
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
# Input:S[1] x Weight:S[0] = Output:P
|
||||
# All-Reduce(Output) + bias = res
|
||||
# Input:S[1]
|
||||
|
@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
|
|||
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
|
||||
# All-Gather(Output)
|
||||
# Input:B
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
|
||||
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
|
||||
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
|
||||
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
|
||||
|
||||
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
|
||||
if bias is not None:
|
||||
bias = bias.torch_tensor()
|
||||
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
|
||||
|
||||
output = ColoTensor.init_from_torch_tensor(
|
||||
output_parallel,
|
||||
|
@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg):
|
|||
# Add communication logic before and after linear call.
|
||||
ret_tensor = None
|
||||
if not weight.has_spec(): # No Model Parallel Applied
|
||||
assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
|
||||
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
|
||||
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
weight = weight.torch_tensor()
|
||||
bias = bias.torch_tensor()
|
||||
if bias is not None:
|
||||
bias = bias.torch_tensor()
|
||||
ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
|
||||
elif weight.spec.num_action == 1: # Single Model Parallel Applied
|
||||
compute_patterns = weight.spec.compute_patterns
|
||||
if ComputePattern.TP1DRow in compute_patterns:
|
||||
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
|
||||
ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
|
||||
elif ComputePattern.TP1DCol in compute_patterns:
|
||||
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
|
||||
ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl
|
|||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
|
||||
|
||||
|
||||
@colo_op_impl(torch.nn.functional.cross_entropy)
|
||||
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
||||
arg_num = len(args)
|
||||
|
@ -27,13 +28,13 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
|||
if isinstance(target, ColoTensor):
|
||||
target = target.torch_tensor()
|
||||
|
||||
if input_tensor.spec.is_gathered(): # Input is gathered
|
||||
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
|
||||
input_tensor.torch_tensor(), target, weight))
|
||||
if input_tensor.spec.is_gathered(): # Input is gathered
|
||||
return ColoTensor.init_from_torch_tensor(
|
||||
torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
|
||||
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
|
||||
if input_tensor.spec.is_1Dcol():
|
||||
return ColoTensor.init_from_torch_tensor(
|
||||
VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
|
||||
if input_tensor.spec.is_1D_col():
|
||||
return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
|
||||
target))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from enum import Enum
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing import Optional, List
|
||||
from numpy import prod
|
||||
|
||||
__all__ = ['replicate', 'shard']
|
||||
|
||||
|
@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
|
|||
assert process_group is not None
|
||||
assert isinstance(dims, list) and isinstance(num_partitions, list)
|
||||
assert len(dims) == len(num_partitions)
|
||||
assert prod(num_partitions) == process_group.size()
|
||||
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
|
||||
|
|
|
@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
|
|||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
# TODO (ver217): remove TP1DRow_<ops>
|
||||
TP1DRow = 0
|
||||
TP1DCol = 9
|
||||
TP1DRow_Linear = 1
|
||||
TP1DCol_Linear = 2
|
||||
TP1DRow_Embedding = 3
|
||||
TP1DCol_Embedding = 4
|
||||
TP1DRow_mm = 5
|
||||
TP1DCol_mm = 6
|
||||
ZeRO = 7
|
||||
DP = 8
|
||||
TP1D = 0
|
||||
ZeRO = 1
|
||||
DP = 2
|
||||
|
||||
|
||||
class ParallelAction(object):
|
||||
|
@ -45,14 +37,14 @@ class TensorSpec(object):
|
|||
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
|
||||
# parallel_action_list = [
|
||||
# ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
|
||||
# ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
# ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
# ]
|
||||
# When the ColoTensor is initialized,
|
||||
# we first splitting tensor according to ParallelAction of ZeRO,
|
||||
# then splitting tensor according to ParallelAction of TP1DRow_Linear.
|
||||
# then splitting tensor according to ParallelAction of TP1D_Linear.
|
||||
# During Linear computation
|
||||
# Before Linear Op, we gather the tensors according to ZeRO.
|
||||
# We perform Linear Op according to compute pattern of TP1DRow_Linear.
|
||||
# We perform Linear Op according to compute pattern of TP1D_Linear.
|
||||
# After Linear Op, we split the tensors according to ZeRO.
|
||||
|
||||
def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
|
||||
|
@ -90,10 +82,17 @@ class TensorSpec(object):
|
|||
|
||||
def is_gathered(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
or (len(self.dist_spec.num_partitions) == 1
|
||||
and self.dist_spec.num_partitions[0] == 1) \
|
||||
or (self.dist_spec.process_group.size() == 1)
|
||||
|
||||
def is_1Dcol(self):
|
||||
def is_1D_col(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
|
||||
|
||||
def is_1D_row(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD \
|
||||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
return self.get_action_by_compute_pattern(compute_pattern) is not None
|
||||
|
|
|
@ -40,7 +40,7 @@ class Conv1D(nn.Module):
|
|||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
|
|||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
|
|
@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
|
|||
def init_1d_row(weight):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
|
|||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
|
|||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
|
|||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
|
|
@ -86,35 +86,43 @@ def set_seed(seed):
|
|||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def init_1d_row_linear(weight):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight, gather_out=True):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
|
||||
gather_out=gather_out)])
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
|
||||
ParallelAction(priority=1,
|
||||
compute_pattern=ComputePattern.TP1D,
|
||||
parallel_mode=ParallelMode.PARALLEL_1D,
|
||||
gather_out=gather_out)
|
||||
])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def run_1d_hybrid_tp(model_name):
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
|
@ -124,7 +132,7 @@ def run_1d_hybrid_tp(model_name):
|
|||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
model_torch = model_builder(checkpoint=True)
|
||||
model_torch = model_torch.cuda()
|
||||
|
@ -173,7 +181,7 @@ def run_1d_hybrid_tp(model_name):
|
|||
if rank == 0:
|
||||
model_torch.eval()
|
||||
colo_optimizer_torch.zero_grad()
|
||||
|
||||
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
|
@ -217,11 +225,11 @@ def run_1d_hybrid_tp(model_name):
|
|||
assert torch.allclose(p1, p2)
|
||||
else:
|
||||
# TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
|
||||
if p1.size(-1) < p2.size(-1): # col
|
||||
if p1.size(-1) < p2.size(-1): # col
|
||||
world_size = p2.size(-1) // p1.size(-1)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
|
||||
|
||||
elif p1.size(0) < p2.size(0): # row
|
||||
|
||||
elif p1.size(0) < p2.size(0): # row
|
||||
world_size = p2.size(0) // p1.size(0)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
|
||||
|
||||
|
@ -376,7 +384,7 @@ def _run_pretrain_load():
|
|||
if isinstance(param, ColoParameter):
|
||||
c1 += 1
|
||||
else:
|
||||
c2 +=1
|
||||
c2 += 1
|
||||
dict_col[name] = param
|
||||
assert c_ref == c1
|
||||
assert c2 == 0
|
||||
|
@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port):
|
|||
for name in ['bert', 'simple_net']:
|
||||
run_1d_hybrid_tp(name)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
# @parameterize('world_size', [1, 4])
|
||||
|
|
Loading…
Reference in New Issue