[FAW] parallel FreqAwareEmbedding (#1424)

pull/1430/head^2
Jiarui Fang 2022-08-10 13:44:30 +08:00 committed by GitHub
parent 0d212183c4
commit cb98cf5558
4 changed files with 272 additions and 2 deletions


@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)


def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = pg.tp_world_size()
    if world_size == 1:
        return x

    # TODO: enabling mpi backend to support CPU all_to_all
    assert x.device.type == 'cuda', "Currently, the collective function dual_all_to_all only supports nccl backend"

    shapes = list(x.size())
    shapes[scatter_dim] = shapes[scatter_dim] // world_size

    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

    return torch.cat(gather_list, dim=gather_dim).contiguous()


class _DualAllToAll(torch.autograd.Function):
    """All-to-all that splits the input along `scatter_dim` and concatenates the
    received chunks along `gather_dim`; the backward pass swaps the two dims."""

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.pg = pg
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None


def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
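Note (not part of the diff): a minimal sketch of the shape transform dual_all_to_all performs, assuming 2 ranks and a per-rank tensor x of shape [4, 8] on a ProcessGroup pg:

# per rank: x is [4, 8], i.e. the full batch but half of the embedding dim
out = dual_all_to_all(x, pg, scatter_dim=0, gather_dim=-1)
# per rank: out is [2, 16], i.e. half of the batch but the full embedding dim
# the backward pass swaps the dims, turning a [2, 16] gradient back into a [4, 8] one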


@@ -1,5 +1,6 @@
from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag

__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']


@@ -0,0 +1,136 @@
import torch
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple
from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter
from .._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be no smaller than the world size " \
        f"{world_size} of the process group"

    chunk_size = embedding_dim // world_size
    threshold = embedding_dim % world_size
    # if embedding dim is divisible by world size
    if threshold == 0:
        return rank * chunk_size, (rank + 1) * chunk_size, True

    # align with the split strategy of torch.tensor_split
    size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
    offset = sum(size_list[:rank])
    return offset, offset + size_list[rank], False
class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 debug=True):
        super(ParallelFreqAwareEmbeddingBag,
              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                             sparse, mode, include_last_offset)

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.debug = debug

        self.partition_start_index, self.partition_end_index, divisible = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
        if _weight is None:
            # create this rank's shard of the weight on the host, then attach the TP process group
            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
                                                                       self.embedding_dim_per_partition,
                                                                       device='cpu',
                                                                       dtype=dtype),
                                                           requires_grad=True,
                                                           spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
            self.init_parameters()
        else:
            assert isinstance(_weight, ColoParameter), "initialized weight must be a ColoParameter"
            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                    ComputeSpec(ComputePattern.TP1D))
            self._weight = _weight
    @property
    def weight(self):
        return self.cache_weight_mgr.cpu_weight

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        yield 'weight', self.cache_weight_mgr.cuda_cached_weight

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield self.cache_weight_mgr.cuda_cached_weight

    @torch.no_grad()
    def init_parameters(self):
        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
        if self.padding_idx is not None:
            self._weight[self.padding_idx].fill_(0)

    def preprocess(self,
                   cuda_row_num: int,
                   ids_freq_mapping: Optional[List[int]] = None,
                   warmup_ratio: float = 0.7,
                   buffer_size: int = 50_000):
        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
        with torch.no_grad():
            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)

        output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        # each rank holds the full batch but only its slice of the embedding dim;
        # the all-to-all scatters the batch dim and gathers the embedding dim, so
        # every rank ends up with its batch shard over the full embedding dim
        output_full = dual_all_to_all(output_shard,
                                      self._weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full
    @classmethod
    def from_pretrained(cls,
                        embedding: torch.Tensor,
                        freeze: bool = True,
                        padding_idx: Optional[int] = None,
                        max_norm: Optional[float] = None,
                        norm_type: float = 2.,
                        scale_grad_by_freq: bool = False,
                        sparse: bool = False,
                        mode: str = 'mean',
                        include_last_offset: bool = False,
                        debug: bool = True,
                        cuda_row_num: int = 100_000,
                        ids_freq_mapping: Optional[List[int]] = None,
                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
        rows, cols = embedding.shape
        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
                            include_last_offset, debug)
        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
        return embedding_bag
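Note (not part of the new file): a worked example of the partitioning and of the output layout produced by forward, assuming 4 ranks:

# get_partition follows torch.tensor_split: for embedding_dim=10, world_size=4 the
# per-rank column ranges are [0, 3), [3, 6), [6, 8), [8, 10), i.e. sizes 3, 3, 2, 2
start, end, divisible = get_partition(10, 1, 4)    # -> (3, 6, False) on rank 1
# in forward(), each rank embeds against its column shard; dual_all_to_all with
# scatter_dim=0 and gather_dim=-1 then leaves every rank with 1/world_size of the
# bags over the full embedding dim, which is why the test below splits the
# reference gradient along dim 0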


@@ -3,9 +3,13 @@ from functools import partial
import torch
import torch.multiprocessing as mp
import numpy as np
import random

import colossalai
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8
def set_seed(seed):
    """
    To achieve reproducible results, it's necessary to fix random seeds
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def synthesize_1d_sparse_feature(
        batch_size,
        num_embed,
@@ -128,7 +141,91 @@ def test_freq_aware_embed():
        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
def gather_tensor(tensor, rank, world_size):
    gather_list = []
    if rank == 0:
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]

    torch.distributed.gather(tensor, gather_list, dst=0)
    return gather_list
def run_parallel_freq_aware_embed(rank, world_size):
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=batch_size * 2)

    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    assert torch.allclose(
        weight_in_rank,
        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    set_seed(4321)
    for i in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        # the forward output is sharded along the batch dim, so each rank only
        # back-propagates its slice of the reference gradient
        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)
            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_parallel_freq_aware_embed(rank, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
    # test_freq_aware_embed()
    # test_chunkmgr_admit()
    test_parallel_freq_aware_embed(2)
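Usage note (the file name below is an assumption, not taken from the diff): running the module directly goes through the __main__ block and spawns 2 processes, while pytest selects the dist marker and runs the parametrized world sizes 1 and 4, so a multi-GPU machine with NCCL is required:

# python test_cache_embedding.py
# pytest -m dist test_cache_embedding.py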