mirror of https://github.com/hpcaitech/ColossalAI
[FAW] parallel FreqAwareEmbedding (#1424)
parent 0d212183c4
commit cb98cf5558
@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)


def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = pg.tp_world_size()
    if world_size == 1:
        return x

    # TODO: enabling mpi backend to support CPU all_to_all
    assert x.device.type == 'cuda', "Currently, the collective function dual_all_to_all only supports nccl backend"

    shapes = list(x.size())
    shapes[scatter_dim] = shapes[scatter_dim] // world_size

    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

    return torch.cat(gather_list, dim=gather_dim).contiguous()


class _DualAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.pg = pg
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None


def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)

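A shape sketch for dual_all_to_all (illustrative only; pg is assumed to be a colossalai ProcessGroup with tp_degree=4 and the tensors live on CUDA): the call trades a shard along gather_dim for a shard along scatter_dim, so each rank ends up with 1/world_size of scatter_dim and the full gather_dim, and backward reverses it by swapping the two dims.

    # hypothetical per-rank shard: (batch=8, hidden/4=4)
    x = torch.randn(8, 4, device='cuda')
    y = dual_all_to_all(x, pg, scatter_dim=0, gather_dim=-1)
    assert y.shape == (2, 16)    # (batch/4, full hidden)
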
@@ -1,5 +1,6 @@
from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag

__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']

@@ -0,0 +1,136 @@
import torch
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple

from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter
from .._utils import dual_all_to_all

from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup


def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    # returns (start, end, divisible) for this rank's slice of the embedding dimension
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be at least as large as the world size " \
        f"{world_size} of the process group"
    chunk_size = embedding_dim // world_size
    threshold = embedding_dim % world_size
    # if embedding dim is divisible by world size
    if threshold == 0:
        return rank * chunk_size, (rank + 1) * chunk_size, True

    # align with the split strategy of torch.tensor_split
    size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
    offset = sum(size_list[:rank])
    return offset, offset + size_list[rank], False

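# Illustrative sketch (assumed values, not from the commit): for embedding_dim=10 and
# world_size=4, chunk_size is 2 and threshold is 2, so the per-rank widths are [3, 3, 2, 2],
# matching torch.tensor_split along the last dim:
#     for r in range(4):
#         print(get_partition(10, r, 4))
#     # -> (0, 3, False), (3, 6, False), (6, 8, False), (8, 10, False)
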
class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 debug=True):
        super(ParallelFreqAwareEmbeddingBag,
              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                             sparse, mode, include_last_offset)

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.debug = debug

        self.partition_start_index, self.partition_end_index, divisible = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

        if _weight is None:
            # create the local shard of the weight on CPU, then attach the tensor-parallel process group
            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
                                                                       self.embedding_dim_per_partition,
                                                                       device='cpu',
                                                                       dtype=dtype),
                                                           requires_grad=True,
                                                           spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
            self.init_parameters()
        else:
            assert isinstance(_weight, ColoParameter), "initialized weight must be of type ColoParameter"
            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                    ComputeSpec(ComputePattern.TP1D))
            self._weight = _weight

    @property
    def weight(self):
        return self.cache_weight_mgr.cpu_weight

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        yield 'weight', self.cache_weight_mgr.cuda_cached_weight

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield self.cache_weight_mgr.cuda_cached_weight

    @torch.no_grad()
    def init_parameters(self):
        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
        if self.padding_idx is not None:
            self._weight[self.padding_idx].fill_(0)

    def preprocess(self,
                   cuda_row_num: int,
                   ids_freq_mapping: Optional[List[int]] = None,
                   warmup_ratio: float = 0.7,
                   buffer_size: int = 50_000):
        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
        with torch.no_grad():
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

        output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        output_full = dual_all_to_all(output_shard,
                                      self._weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full

    @classmethod
    def from_pretrained(cls,
                        embedding: torch.Tensor,
                        freeze: bool = True,
                        padding_idx: Optional[int] = None,
                        max_norm: Optional[float] = None,
                        norm_type: float = 2.,
                        scale_grad_by_freq: bool = False,
                        sparse: bool = False,
                        mode: str = 'mean',
                        include_last_offset: bool = False,
                        debug: bool = True,
                        cuda_row_num: int = 100_000,
                        ids_freq_mapping: Optional[List[int]] = None,
                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
        rows, cols = embedding.shape
        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
                            include_last_offset, debug=debug)
        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
        # honour `freeze` by toggling requires_grad on the cached CUDA weight
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad = not freeze
        return embedding_bag

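A minimal usage sketch for ParallelFreqAwareEmbeddingBag (it mirrors the test further down; the sizes are assumptions for illustration): each rank holds a (num_embeddings, embedding_dim_per_partition) slice of the table on CPU, keeps cuda_row_num hot rows cached on GPU, and forward returns the full embedding_dim for its 1/world_size share of the bags.

    # assumes colossalai.launch has already initialized torch.distributed on every rank
    weight = torch.rand(100, 16)    # the full (num_embeddings, embedding_dim) table, identical on all ranks
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=8)
    # indices/offsets are a flat id tensor plus bag offsets, e.g. from synthesize_1d_sparse_feature below
    out = model(indices, offsets)    # per rank: (num_bags // world_size, 16)
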
@@ -3,9 +3,13 @@ from functools import partial
import torch
import torch.multiprocessing as mp
import numpy as np
import random

import colossalai
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag

@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8


def set_seed(seed):
    """
    To achieve reproducible results, it's necessary to fix random seeds
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def synthesize_1d_sparse_feature(
    batch_size,
    num_embed,
@@ -128,7 +141,91 @@ def test_freq_aware_embed():
        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"


def gather_tensor(tensor, rank, world_size):
    gather_list = []
    if rank == 0:
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]

    torch.distributed.gather(tensor, gather_list, dst=0)
    return gather_list


def run_parallel_freq_aware_embed(rank, world_size):
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=batch_size * 2)

    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    assert torch.allclose(
        weight_in_rank,
        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    set_seed(4321)
    for i in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        # gather every rank's slice of the output on rank 0 and compare with the reference model
        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)

            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    # write cached rows back to the CPU weight before comparing against the reference weights
    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_parallel_freq_aware_embed(rank, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_freq_aware_embed()
    # test_chunkmgr_admit()
    test_parallel_freq_aware_embed(2)