[Tensor] add embedding tp1d row (#904)

pull/907/head
Ziyue Jiang 3 years ago committed by GitHub
parent 16122d5fac
commit f593a5637e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,7 +9,7 @@ from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table)
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
if not input_tensor.is_gathered():
@ -25,6 +25,37 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
output.gather()
return output
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Embedding)
if not input_tensor.is_gathered():
input_tensor.gather()
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
num_embeddings_per_partition = weight.size(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# Build the mask.
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
(input_tensor.torch_tensor() >= vocab_end_index)
# Mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
masked_input[input_mask] = 0
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(),
*args, **kwargs)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(output)
return output
@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
@ -48,7 +79,9 @@ def colo_embedding(types, args, kwargs, pg):
return ColoTensor.init_from_torch_tensor(output)
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.shard_spec.compute_patterns
if ComputePattern.TP1DCol_Embedding in compute_patterns:
if ComputePattern.TP1DRow_Embedding in compute_patterns:
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
elif ComputePattern.TP1DCol_Embedding in compute_patterns:
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
else:
raise NotImplementedError

@ -166,6 +166,7 @@ class ColoTensor(object):
dim = -1
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
self._shard_pattern = ShardPattern.NA
self._size = self._torch_tensor.size()
def is_gathered(self) -> bool:
return self._shard_pattern == ShardPattern.NA

@ -5,7 +5,6 @@ from .utils.dummy_data_generator import DummyDataGenerator
from .registry import non_distributed_component_funcs
from colossalai.utils.cuda import get_current_device
class SimpleNet(CheckpointModule):
"""
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
@ -13,12 +12,14 @@ class SimpleNet(CheckpointModule):
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.embed = nn.Embedding(20, 4)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
def forward(self, x):
x = self.embed(x)
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
@ -26,11 +27,12 @@ class SimpleNet(CheckpointModule):
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4, device=get_current_device())
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
data = torch.randint(low=0, high=20, size=(16,20), device=get_current_device())
label = torch.randint(low=0, high=2, size=(16,4), device=get_current_device())
return data, label

@ -65,10 +65,60 @@ def run_embedding_tp1d_col_test():
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
check_equal(W_grad, layer.weight.grad)
def run_embedding_tp1d_row_test():
device = get_current_device()
dtype = torch.float32
DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
num_embeddings = 12
embedding_dim = 32
local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
layer = torch.nn.Embedding(num_embeddings, embedding_dim)
A_master = torch.tensor((0,3,6,9), device=device)
A = broadcast_tensor_chunk(A_master, chunk_size=1)
W_shape = (num_embeddings, embedding_dim)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
W = broadcast_tensor_chunk(W_master, chunk_size=1)
W.requires_grad = True
# replace the torch nn.Parameters with ColoTensor
sharded_weight = ColoTensor.init_from_torch_tensor(W)
parallel_action_list = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec = TensorSpec(parallel_action_list)
sharded_weight.set_spec(spec) # reshard
replace_parameter_add_grad(layer, sharded_weight)
out = layer(A)
replace_parameter_add_grad(layer_master, W_master)
C_master = layer_master(A_master)
C = C_master.clone()
check_equal(out, C)
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank]
check_equal(W_grad, layer.weight.grad)
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_embedding_tp1d_col_test()
run_embedding_tp1d_row_test()
@pytest.mark.dist
@parameterize('world_size', [1, 4])

@ -47,6 +47,11 @@ def run_1d_col_tp():
]
spec_col = TensorSpec(parallel_action_list_col)
parallel_action_list_embedding_col = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
@ -60,6 +65,8 @@ def run_1d_col_tp():
p.set_spec(spec_col)
if 'proj2' in name and 'weight' in name:
p.set_spec(spec_row)
if 'embed' in name and 'weight' in name:
p.set_spec(spec_embedding_col)
model = model.cuda()
@ -172,6 +179,11 @@ def run_1d_row_tp():
]
spec = TensorSpec(parallel_action_list)
parallel_action_list_embedding_row = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
@ -183,6 +195,8 @@ def run_1d_row_tp():
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
p.set_spec(spec)
if 'embed' in name and 'weight' in name:
p.set_spec(spec_embedding_row)
model = model.cuda()
@ -227,7 +241,7 @@ def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_1d_row_tp()
run_1d_col_tp()
@pytest.mark.dist
@parameterize('world_size', [1, 4])
@ -238,6 +252,6 @@ def test_simple_net(world_size):
if __name__ == '__main__':
# test_simple_net()
test_simple_net()
# test_model_parameters()
test_colo_optimizer()
# test_colo_optimizer()

Loading…
Cancel
Save