[embedding] tablewise sharding polish (#1535)

pull/1538/head^2
Jiarui Fang 2022-09-02 11:09:37 +08:00 committed by GitHub
parent 56159049e8
commit 87134524fd
2 changed files with 90 additions and 89 deletions

View File

@@ -1,5 +1,5 @@
 import torch
-from typing import Union, Optional
+from typing import Union, Optional, List
 from colossalai.tensor import ColoTensor
 import torch
 import torch.distributed as dist
@@ -231,3 +231,54 @@ class _DualAllToAll(torch.autograd.Function):
 def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
     return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
+
+
+### table wise embedding shard
+def _all_to_all_for_tablewise(x: torch.Tensor,
+                              pg: ProcessGroup,
+                              scatter_strides: List[int],
+                              gather_strides: List[int],
+                              forward=True) -> torch.Tensor:
+    world_size = pg.tp_world_size()
+    rank = pg.tp_local_rank()
+    if world_size == 1:
+        return x
+    assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
+    if forward:
+        scatter_list = list(x.split(scatter_strides, 0))
+        gather_list = [
+            torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device)
+            for i in range(world_size)
+        ]
+        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
+        return torch.cat(gather_list, 1).contiguous()
+    else:
+        # split on dim 1, lose contiguity
+        scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
+        gather_list = [
+            torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device)
+            for i in range(world_size)
+        ]
+        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
+        return torch.cat(gather_list, 0).contiguous()
+
+
+class _DualAllToAllForTablewise(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, pg, scatter_strides, gather_strides):
+        ctx.pg = pg
+        ctx.scatter_strides = scatter_strides
+        ctx.gather_strides = gather_strides
+        return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides,
+                                         forward=False), None, None, None
+
+
+def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides):
+    return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
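Note on the helper moved into this file: it scatters each rank's local embedding-bag output along the batch dimension (dim 0, sized by scatter_strides) and gathers along the embedding dimension (dim 1, sized by gather_strides), so after the collective every rank holds only its batch shard but with the columns of every table. Below is a minimal single-process sketch of that exchange using plain tensor ops instead of an initialized torch.distributed process group; the toy sizes and the local_output list are assumptions for illustration, not part of the library API.

import torch

world_size = 2
batch_size = 5
embedding_dim_per_rank = [8, 4]    # hypothetical: rank 0 owns tables totalling dim 8, rank 1 dim 4
remains = batch_size % world_size
scatter_strides = [batch_size // world_size + int(i < remains) for i in range(world_size)]    # [3, 2]

# local_output[r]: what rank r computed locally -> shape (full batch, its tables' embedding width)
local_output = [torch.randn(batch_size, embedding_dim_per_rank[r]) for r in range(world_size)]

# emulate torch.distributed.all_to_all: rank src sends its dst-th batch chunk to rank dst
chunks = [local_output[src].split(scatter_strides, 0) for src in range(world_size)]
output_full = [
    torch.cat([chunks[src][dst] for src in range(world_size)], dim=1)
    for dst in range(world_size)
]

for r in range(world_size):
    # each rank ends up with (its batch shard, sum of all ranks' embedding widths)
    assert output_full[r].shape == (scatter_strides[r], sum(embedding_dim_per_rank))

The backward of _DualAllToAllForTablewise runs the same exchange with the two stride lists swapped and forward=False, which reverses the scatter and gather dimensions for the gradient.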

View File

@@ -1,14 +1,15 @@
 import torch
+import torch.nn.functional as F
 import torch.distributed as dist
 import torch.nn as nn
-from typing import List, Optional, Iterator, Tuple
+from typing import List
 import abc
 
 from .freq_aware_embedding import FreqAwareEmbeddingBag
-from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
-from .cache_mgr import CachedParamMgr, EvictionStrategy
+from colossalai.tensor import ProcessGroup
+from .cache_mgr import EvictionStrategy
+from colossalai.nn._ops._utils import dual_all_to_all_tablewise
 
 
 class TablewiseEmbeddingBagConfig:
@@ -19,6 +20,7 @@ class TablewiseEmbeddingBagConfig:
         ...
         return embedding_bag_config_list
     '''
+
     def __init__(self,
                  num_embeddings: int,
                  cuda_row_num: int,
@@ -36,50 +38,10 @@ class TablewiseEmbeddingBagConfig:
         self.name = name
 
 
-def _all_to_all_for_tablewise(x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True) -> torch.Tensor:
-    world_size = pg.tp_world_size()
-    rank = pg.tp_local_rank()
-    if world_size == 1:
-        return x
-    assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
-    if forward:
-        scatter_list = list(x.split(scatter_strides, 0))
-        gather_list = [torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype,
-                                   device=x.device) for i in range(world_size)]
-        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
-        return torch.cat(gather_list, 1).contiguous()
-    else:
-        # split on dim 1, lose contiguity
-        scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
-        gather_list = [torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype,
-                                   device=x.device) for i in range(world_size)]
-        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
-        return torch.cat(gather_list, 0).contiguous()
-
-
-class _DualAllToAllForTablewise(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x, pg, scatter_strides, gather_strides):
-        ctx.pg = pg
-        ctx.scatter_strides = scatter_strides
-        ctx.gather_strides = gather_strides
-        return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)
-
-    @staticmethod
-    def backward(ctx, grad):
-        return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), None, None, None
-
-
-def _dual_all_to_all(x, pg, scatter_strides, gather_strides):
-    return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
-
-
 class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
-    '''
+    """
     every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
-    '''
+    """
 
     def __init__(self,
                  embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
@@ -99,13 +61,13 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
         super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
-        self.global_table_assign_list = [config.assigned_rank for config in embedding_bag_config_list]
+        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
         self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
         self.global_tables_num = len(embedding_bag_config_list)
         self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
         self.assigned_table_list: List[int] = []
-        for i, rank in enumerate(self.global_table_assign_list):
+        for i, rank in enumerate(self.rank_of_tables):
             if rank == self.rank:
                 self.assigned_table_list.append(i)
         self.include_last_offset = include_last_offset
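The renamed rank_of_tables determines which tables this process owns and, in the next hunk, how wide each rank's slice of the all_to_all output is. A small sketch of that bookkeeping with assumed table sizes and rank assignments (the values are illustrative, not taken from the diff):

import torch

rank = 0                                      # pretend this process is rank 0 of 2
rank_of_tables = [0, 1, 0]                    # hypothetical: three tables assigned to ranks 0, 1, 0
global_table_num_embeddings_list = [4, 3, 5]  # rows per table
embedding_dim = 8

global_tables_offsets = torch.cumsum(torch.tensor([0] + global_table_num_embeddings_list), 0)
assigned_table_list = [i for i, r in enumerate(rank_of_tables) if r == rank]
embedding_dim_per_rank = [0, 0]
for r in rank_of_tables:
    embedding_dim_per_rank[r] += embedding_dim

print(global_tables_offsets)     # tensor([ 0,  4,  7, 12])
print(assigned_table_list)       # [0, 2]
print(embedding_dim_per_rank)    # [16, 8]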
@@ -118,37 +80,30 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
             if config.assigned_rank != self.rank:
                 continue
             self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(
-                    num_embeddings=config.num_embeddings,
-                    embedding_dim=embedding_dim,
-                    padding_idx=padding_idx,
-                    max_norm=max_norm,
-                    norm_type=norm_type,
-                    scale_grad_by_freq=scale_grad_by_freq,
-                    sparse=sparse,
-                    _weight=config.initial_weight,
-                    mode=mode,
-                    include_last_offset=include_last_offset,
-                    dtype=dtype,
-                    device=device,
-                    cuda_row_num=config.cuda_row_num ,
-                    ids_freq_mapping=config.ids_freq_mapping,
-                    warmup_ratio=warmup_ratio,
-                    buffer_size=config.buffer_size,
-                    pin_weight=pin_weight,
-                    evict_strategy=evict_strategy
-                )
-            )
+                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
+                                      embedding_dim=embedding_dim,
+                                      padding_idx=padding_idx,
+                                      max_norm=max_norm,
+                                      norm_type=norm_type,
+                                      scale_grad_by_freq=scale_grad_by_freq,
+                                      sparse=sparse,
+                                      _weight=config.initial_weight,
+                                      mode=mode,
+                                      include_last_offset=include_last_offset,
+                                      dtype=dtype,
+                                      device=device,
+                                      cuda_row_num=config.cuda_row_num,
+                                      ids_freq_mapping=config.ids_freq_mapping,
+                                      warmup_ratio=warmup_ratio,
+                                      buffer_size=config.buffer_size,
+                                      pin_weight=pin_weight,
+                                      evict_strategy=evict_strategy))
 
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
-        for rank in self.global_table_assign_list:
+        for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
-        #print("global_table_assign_list {}".format(self.global_table_assign_list))
-        #print("global_table_num_embeddings_list {}".format(self.global_table_num_embeddings_list))
-        #print("global_tables_offsets {}".format(self.global_tables_offsets))
-        #
 
     def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
         # determine indices to handle
         batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -158,27 +113,22 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
             if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
                 # till the end special case
                 indices_end_position = indices.shape[0]
-            else :
+            else:
                 indices_end_position = offsets[batch_size * (handle_table + 1)]
             local_indices = indices[indices_start_position:indices_end_position] - \
                 self.global_tables_offsets[handle_table]
             if self.include_last_offset:
-                local_offsets = offsets[batch_size * handle_table:batch_size
-                                        * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
+                local_offsets = offsets[batch_size * handle_table:batch_size *
+                                        (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
             else:
-                local_offsets = offsets[batch_size * handle_table:batch_size
-                                        * (handle_table + 1)] - offsets[batch_size * (handle_table)]
+                local_offsets = offsets[batch_size * handle_table:batch_size *
+                                        (handle_table + 1)] - offsets[batch_size * (handle_table)]
             local_per_sample_weights = None
             if per_sample_weights != None:
                 local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
-            local_output_list.append(
-                self.freq_aware_embedding_bag_list[i](
-                    local_indices,
-                    local_offsets,
-                    local_per_sample_weights
-                )
-            )
+            local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
+                                                                           local_per_sample_weights))
 
         # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
         local_output = torch.cat(local_output_list, 1)
@@ -186,7 +136,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
         # use all_to_all
         remains = batch_size % self.world_size
         scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = _dual_all_to_all(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
         if shape_hook is not None:
             output_full = shape_hook(output_full)
         return output_full
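For reference, the slicing arithmetic in forward assumes a table-major input layout: offsets stacks each table's batch_size bag offsets back to back, and indices carries globally shifted ids that are un-shifted with global_tables_offsets. A toy sketch of recovering one assigned table's local inputs under that assumed layout (with include_last_offset=False); it illustrates the indexing only and is not the module's documented input contract.

import torch

batch_size = 2
global_tables_offsets = torch.tensor([0, 4, 7])    # cumsum of row counts [4, 3] for two tables

# sample 0 hits table0:{1, 3} and table1:{0}; sample 1 hits table0:{2} and table1:{1, 2}
indices = torch.tensor([1, 3, 2, 4, 5, 6])         # table 1 ids are shifted by its base offset 4
offsets = torch.tensor([0, 2, 3, 4])               # table 0 bags start at [0, 2], table 1 bags at [3, 4]

handle_table = 1                                   # pretend this rank was assigned table 1
indices_start_position = offsets[batch_size * handle_table]
indices_end_position = indices.shape[0]            # last table -> the "till the end" special case
local_indices = indices[indices_start_position:indices_end_position] - global_tables_offsets[handle_table]
local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * handle_table]

print(local_indices)    # tensor([0, 1, 2])
print(local_offsets)    # tensor([0, 1])

Each assigned table's FreqAwareEmbeddingBag then runs on its slice, the results are concatenated along dim 1, and dual_all_to_all_tablewise redistributes them across ranks as shown in the last hunk.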