
add vocab embedding layer

Branch: pull/4157/head
Commit 507c0ad368, authored by FoolPlayer 1 year ago, committed by Frank Lee

Changed files (2):
  1. colossalai/shardformer/layer/layers.py (65 lines changed)
  2. tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py (45 lines changed)

colossalai/shardformer/layer/layers.py (65 lines changed)

@@ -139,6 +139,7 @@ class Linear1D_Col(ParallelModule):
         with self.randomizer.fork_rng(enable_cpu=True):
             self.reset_parameters(weight_initializer, bias_initializer)
 
+    @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                            **kwargs) -> ParallelModule:
         r"""
@@ -587,6 +588,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -596,21 +599,63 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.process_group = process_group
 
-        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings_per_partition = num_embeddings
+        tensor_parallel_size = dist.get_world_size(group=process_group)
+        tensor_parallel_rank = dist.get_rank(group=process_group)
+        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings = self.num_embeddings_per_partition
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
 
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+        # self.reset_parameters(weight_initializer)
+        # self._set_tensor_parallel_attributes()
+        # set_parallel_input(False)
+        # env.vocab_parallel = True
 
+    @staticmethod
+    def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native PyTorch embedding module to a parallel module.
+        """
+        # get the original attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+
+        # ensure only one process group is used
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        # create the parallel module
+        vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
+                                                      embedding_dim=embedding_dim,
+                                                      padding_idx=padding_idx,
+                                                      device=device,
+                                                      process_group=process_group,
+                                                      *args,
+                                                      **kwargs)
+
+        with torch.no_grad():
+            # shard and slice the weight along the vocabulary (num_embeddings) dimension;
+            # the shape of the weight is (num_embeddings, embedding_dim)
+            shard_weight = shard_rowwise(module.weight.data, process_group)
+            vocab_embedding_1d.weight.data.copy_(shard_weight)
-        self.reset_parameters(weight_initializer)
-        self._set_tensor_parallel_attributes()
-        set_parallel_input(False)
-        env.vocab_parallel = True
+
+        return vocab_embedding_1d
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
@@ -665,5 +710,5 @@ class VocabParallelEmbedding1D(ParallelLayer):
         # Mask the output embedding.
         output_parallel[input_mask, :] = 0.
         # Reduce across all the model parallel GPUs.
-        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        output = reduce_input(output_parallel, self.process_group)
         return output

tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py (45 lines changed)

@@ -0,0 +1,45 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_vocab_embedding_1d():
    embedding = nn.Embedding(128, 32).to('cuda')
    dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)

    assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
    assert dist_embedding_1d.num_embeddings == 64
    assert dist_embedding_1d.embed_dim == 32

    # check embedding correctness
    x = torch.randint(0, 128, (4, 32)).to('cuda')
    org_out = embedding(x)
    dist_out = dist_embedding_1d(x)
    assert_close(org_out, dist_out)

    # check backward correctness
    org_out.sum().backward()
    dist_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, dist_embedding_1d.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_vocab_embedding_1d()


@rerun_if_address_is_in_use()
def test_vocab_embedding():
    spawn(run_dist, nprocs=2)


if __name__ == '__main__':
    test_vocab_embedding()
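
For context on what the assertions check: with two ranks, from_native_module is expected to split the (128, 32) embedding table row-wise into two (64, 32) shards, and because the forward pass all-reduces the partial lookups, each rank's weight gradient should equal the corresponding chunk of the dense embedding's gradient. Below is a minimal sketch of that row-wise split, assuming an initialized process group; the helper name is hypothetical and is not ColossalAI's shard_rowwise.

    import torch
    import torch.distributed as dist


    def rowwise_shard_for_rank(weight: torch.Tensor, process_group=None) -> torch.Tensor:
        """Return this rank's slice of `weight`, split along dim 0 (the vocabulary
        dimension), matching the shape and gradient assertions in the test above."""
        world_size = dist.get_world_size(group=process_group)
        rank = dist.get_rank(group=process_group)
        # this simple sketch assumes the vocabulary divides evenly across ranks
        assert weight.shape[0] % world_size == 0
        return torch.chunk(weight, world_size, dim=0)[rank].contiguous()

With 2 ranks, a (128, 32) table yields a (64, 32) local shard per rank, which is exactly the shape the test asserts for dist_embedding_1d.weight.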