mirror of https://github.com/hpcaitech/ColossalAI
add vocabembedding layer
parent
45d9384346
commit
507c0ad368
|
@ -139,6 +139,7 @@ class Linear1D_Col(ParallelModule):
|
|||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
r"""
|
||||
|
@ -587,6 +588,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
|
@ -596,21 +599,63 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.process_group = process_group
|
||||
|
||||
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings_per_partition = num_embeddings
|
||||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
||||
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings = self.num_embeddings_per_partition
|
||||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
env.vocab_parallel = True
|
||||
# self.reset_parameters(weight_initializer)
|
||||
# self._set_tensor_parallel_attributes()
|
||||
# set_parallel_input(False)
|
||||
# env.vocab_parallel = True
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
# get the origin attributes
|
||||
num_embeddings = module.num_embeddings
|
||||
embedding_dim = module.embedding_dim
|
||||
padding_idx = module.padding_idx
|
||||
device = module.weight.device
|
||||
|
||||
# ensure only one process group is used
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
process_group = process_group[0]
|
||||
|
||||
# create the parallel module
|
||||
vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
with torch.no_grad():
|
||||
# shard and slice the weight along the vocabulary(num_embeddings) dimension
|
||||
# the shape of the weight is (num_embeddings, embedding_dim)
|
||||
shard_weight = shard_rowwise(module.weight.data, process_group)
|
||||
vocab_embedding_1d.weight.data.copy_(shard_weight)
|
||||
|
||||
return vocab_embedding_1d
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
|
||||
|
@ -665,5 +710,5 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
# Mask the output embedding.
|
||||
output_parallel[input_mask, :] = 0.
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
output = reduce_input(output_parallel, self.process_group)
|
||||
return output
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_vocab_embedding_1d():
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
|
||||
|
||||
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
|
||||
assert dist_embedding_1d.num_embeddings == 64
|
||||
assert dist_embedding_1d.embed_dim == 32
|
||||
|
||||
# check embedding correctness
|
||||
x = torch.randint(0, 128, (4, 32)).to('cuda')
|
||||
org_out = embedding(x)
|
||||
dist_out = dist_embedding_1d(x)
|
||||
assert_close(org_out, dist_out)
|
||||
|
||||
# check backward correctness
|
||||
org_out.sum().backward()
|
||||
dist_out.sum().backward()
|
||||
|
||||
rank = dist.get_rank()
|
||||
target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank]
|
||||
assert_close(target_grad, dist_embedding_1d.weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_vocab_embedding_1d()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_vocab_embedding():
|
||||
spawn(run_dist, nprocs=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_vocab_embedding()
|
Loading…
Reference in New Issue