mirror of https://github.com/hpcaitech/ColossalAI
[embedding] tablewise sharding polish (#1535)
parent
56159049e8
commit
87134524fd
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional, List
|
||||||
from colossalai.tensor import ColoTensor
|
from colossalai.tensor import ColoTensor
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -231,3 +231,54 @@ class _DualAllToAll(torch.autograd.Function):
|
||||||
|
|
||||||
def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
|
def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
|
||||||
return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
|
return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
|
||||||
|
|
||||||
|
|
||||||
|
### table wise embedding shard
|
||||||
|
|
||||||
|
|
||||||
|
def _all_to_all_for_tablewise(x: torch.Tensor,
|
||||||
|
pg: ProcessGroup,
|
||||||
|
scatter_strides: List[int],
|
||||||
|
gather_strides: List[int],
|
||||||
|
forward=True) -> torch.Tensor:
|
||||||
|
world_size = pg.tp_world_size()
|
||||||
|
rank = pg.tp_local_rank()
|
||||||
|
if world_size == 1:
|
||||||
|
return x
|
||||||
|
assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
|
||||||
|
if forward:
|
||||||
|
scatter_list = list(x.split(scatter_strides, 0))
|
||||||
|
gather_list = [
|
||||||
|
torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype, device=x.device)
|
||||||
|
for i in range(world_size)
|
||||||
|
]
|
||||||
|
torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
|
||||||
|
return torch.cat(gather_list, 1).contiguous()
|
||||||
|
else:
|
||||||
|
# split on dim 1, lose contiguity
|
||||||
|
scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
|
||||||
|
gather_list = [
|
||||||
|
torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype, device=x.device)
|
||||||
|
for i in range(world_size)
|
||||||
|
]
|
||||||
|
torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
|
||||||
|
return torch.cat(gather_list, 0).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class _DualAllToAllForTablewise(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, pg, scatter_strides, gather_strides):
|
||||||
|
ctx.pg = pg
|
||||||
|
ctx.scatter_strides = scatter_strides
|
||||||
|
ctx.gather_strides = gather_strides
|
||||||
|
return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad):
|
||||||
|
return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides,
|
||||||
|
forward=False), None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides):
|
||||||
|
return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import List, Optional, Iterator, Tuple
|
from typing import List
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from .freq_aware_embedding import FreqAwareEmbeddingBag
|
from .freq_aware_embedding import FreqAwareEmbeddingBag
|
||||||
|
|
||||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
from colossalai.tensor import ProcessGroup
|
||||||
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
from .cache_mgr import EvictionStrategy
|
||||||
|
|
||||||
|
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
|
||||||
|
|
||||||
|
|
||||||
class TablewiseEmbeddingBagConfig:
|
class TablewiseEmbeddingBagConfig:
|
||||||
|
@ -19,6 +20,7 @@ class TablewiseEmbeddingBagConfig:
|
||||||
...
|
...
|
||||||
return embedding_bag_config_list
|
return embedding_bag_config_list
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_embeddings: int,
|
num_embeddings: int,
|
||||||
cuda_row_num: int,
|
cuda_row_num: int,
|
||||||
|
@ -36,50 +38,10 @@ class TablewiseEmbeddingBagConfig:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _all_to_all_for_tablewise(x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True) -> torch.Tensor:
|
|
||||||
world_size = pg.tp_world_size()
|
|
||||||
rank = pg.tp_local_rank()
|
|
||||||
if world_size == 1:
|
|
||||||
return x
|
|
||||||
assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
|
|
||||||
if forward:
|
|
||||||
scatter_list = list(x.split(scatter_strides, 0))
|
|
||||||
gather_list = [torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype,
|
|
||||||
device=x.device) for i in range(world_size)]
|
|
||||||
torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
|
|
||||||
return torch.cat(gather_list, 1).contiguous()
|
|
||||||
else:
|
|
||||||
# split on dim 1, lose contiguity
|
|
||||||
scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
|
|
||||||
gather_list = [torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype,
|
|
||||||
device=x.device) for i in range(world_size)]
|
|
||||||
torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
|
|
||||||
return torch.cat(gather_list, 0).contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
class _DualAllToAllForTablewise(torch.autograd.Function):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, x, pg, scatter_strides, gather_strides):
|
|
||||||
ctx.pg = pg
|
|
||||||
ctx.scatter_strides = scatter_strides
|
|
||||||
ctx.gather_strides = gather_strides
|
|
||||||
return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad):
|
|
||||||
return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
def _dual_all_to_all(x, pg, scatter_strides, gather_strides):
|
|
||||||
return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
||||||
'''
|
"""
|
||||||
every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
|
every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
|
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
|
||||||
|
@ -99,13 +61,13 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
||||||
super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
|
super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
|
||||||
self.rank = dist.get_rank()
|
self.rank = dist.get_rank()
|
||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
self.global_table_assign_list = [config.assigned_rank for config in embedding_bag_config_list]
|
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
|
||||||
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
|
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
|
||||||
self.global_tables_num = len(embedding_bag_config_list)
|
self.global_tables_num = len(embedding_bag_config_list)
|
||||||
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
|
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
|
||||||
|
|
||||||
self.assigned_table_list: List[int] = []
|
self.assigned_table_list: List[int] = []
|
||||||
for i, rank in enumerate(self.global_table_assign_list):
|
for i, rank in enumerate(self.rank_of_tables):
|
||||||
if rank == self.rank:
|
if rank == self.rank:
|
||||||
self.assigned_table_list.append(i)
|
self.assigned_table_list.append(i)
|
||||||
self.include_last_offset = include_last_offset
|
self.include_last_offset = include_last_offset
|
||||||
|
@ -118,37 +80,30 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
||||||
if config.assigned_rank != self.rank:
|
if config.assigned_rank != self.rank:
|
||||||
continue
|
continue
|
||||||
self.freq_aware_embedding_bag_list.append(
|
self.freq_aware_embedding_bag_list.append(
|
||||||
FreqAwareEmbeddingBag(
|
FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
|
||||||
num_embeddings=config.num_embeddings,
|
embedding_dim=embedding_dim,
|
||||||
embedding_dim=embedding_dim,
|
padding_idx=padding_idx,
|
||||||
padding_idx=padding_idx,
|
max_norm=max_norm,
|
||||||
max_norm=max_norm,
|
norm_type=norm_type,
|
||||||
norm_type=norm_type,
|
scale_grad_by_freq=scale_grad_by_freq,
|
||||||
scale_grad_by_freq=scale_grad_by_freq,
|
sparse=sparse,
|
||||||
sparse=sparse,
|
_weight=config.initial_weight,
|
||||||
_weight=config.initial_weight,
|
mode=mode,
|
||||||
mode=mode,
|
include_last_offset=include_last_offset,
|
||||||
include_last_offset=include_last_offset,
|
dtype=dtype,
|
||||||
dtype=dtype,
|
device=device,
|
||||||
device=device,
|
cuda_row_num=config.cuda_row_num,
|
||||||
cuda_row_num=config.cuda_row_num ,
|
ids_freq_mapping=config.ids_freq_mapping,
|
||||||
ids_freq_mapping=config.ids_freq_mapping,
|
warmup_ratio=warmup_ratio,
|
||||||
warmup_ratio=warmup_ratio,
|
buffer_size=config.buffer_size,
|
||||||
buffer_size=config.buffer_size,
|
pin_weight=pin_weight,
|
||||||
pin_weight=pin_weight,
|
evict_strategy=evict_strategy))
|
||||||
evict_strategy=evict_strategy
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare list shape for all_to_all output
|
# prepare list shape for all_to_all output
|
||||||
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
|
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
|
||||||
for rank in self.global_table_assign_list:
|
for rank in self.rank_of_tables:
|
||||||
self.embedding_dim_per_rank[rank] += embedding_dim
|
self.embedding_dim_per_rank[rank] += embedding_dim
|
||||||
|
|
||||||
#print("global_table_assign_list {}".format(self.global_table_assign_list))
|
|
||||||
#print("global_table_num_embeddings_list {}".format(self.global_table_num_embeddings_list))
|
|
||||||
#print("global_tables_offsets {}".format(self.global_tables_offsets))
|
|
||||||
#
|
|
||||||
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
|
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
|
||||||
# determine indices to handle
|
# determine indices to handle
|
||||||
batch_size = (offsets.shape[0]) // self.global_tables_num
|
batch_size = (offsets.shape[0]) // self.global_tables_num
|
||||||
|
@ -158,27 +113,22 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
||||||
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
|
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
|
||||||
# till the end special case
|
# till the end special case
|
||||||
indices_end_position = indices.shape[0]
|
indices_end_position = indices.shape[0]
|
||||||
else :
|
else:
|
||||||
indices_end_position = offsets[batch_size * (handle_table + 1)]
|
indices_end_position = offsets[batch_size * (handle_table + 1)]
|
||||||
|
|
||||||
local_indices = indices[indices_start_position:indices_end_position] - \
|
local_indices = indices[indices_start_position:indices_end_position] - \
|
||||||
self.global_tables_offsets[handle_table]
|
self.global_tables_offsets[handle_table]
|
||||||
if self.include_last_offset:
|
if self.include_last_offset:
|
||||||
local_offsets = offsets[batch_size * handle_table:batch_size
|
local_offsets = offsets[batch_size * handle_table:batch_size *
|
||||||
* (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
|
(handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
|
||||||
else:
|
else:
|
||||||
local_offsets = offsets[batch_size * handle_table:batch_size
|
local_offsets = offsets[batch_size * handle_table:batch_size *
|
||||||
* (handle_table + 1)] - offsets[batch_size * (handle_table)]
|
(handle_table + 1)] - offsets[batch_size * (handle_table)]
|
||||||
local_per_sample_weights = None
|
local_per_sample_weights = None
|
||||||
if per_sample_weights != None:
|
if per_sample_weights != None:
|
||||||
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
|
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
|
||||||
local_output_list.append(
|
local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
|
||||||
self.freq_aware_embedding_bag_list[i](
|
local_per_sample_weights))
|
||||||
local_indices,
|
|
||||||
local_offsets,
|
|
||||||
local_per_sample_weights
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
|
# get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
|
||||||
local_output = torch.cat(local_output_list, 1)
|
local_output = torch.cat(local_output_list, 1)
|
||||||
|
@ -186,7 +136,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
|
||||||
# use all_to_all
|
# use all_to_all
|
||||||
remains = batch_size % self.world_size
|
remains = batch_size % self.world_size
|
||||||
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
||||||
output_full = _dual_all_to_all(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
||||||
if shape_hook is not None:
|
if shape_hook is not None:
|
||||||
output_full = shape_hook(output_full)
|
output_full = shape_hook(output_full)
|
||||||
return output_full
|
return output_full
|
||||||
|
|
Loading…
Reference in New Issue