ColossalAI/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py

142 lines
5.8 KiB
Python
Raw Normal View History

import torch
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple
from .cached_embedding import CachedEmbeddingBag
2022-08-11 05:43:24 +00:00
from colossalai.nn._ops._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
2022-08-30 06:50:02 +00:00
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    """Compute the slice of the embedding dimension assigned to ``rank``.

    Args:
        embedding_dim: total size of the embedding dimension to split.
        rank: index of the partition whose bounds are requested.
        world_size: number of partitions.

    Returns:
        A ``(start, end, divisible)`` tuple where ``[start, end)`` is the
        half-open column range of the partition and ``divisible`` tells
        whether ``embedding_dim`` splits evenly across ``world_size``
        (always ``True`` when ``world_size`` is 1).

    Raises:
        AssertionError: if ``world_size`` is not 1 and ``embedding_dim`` is
            smaller than ``world_size``.
    """
    # A single partition owns the whole dimension.
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be larger than the world size " \
        f"{world_size} of the process group"

    base, remainder = divmod(embedding_dim, world_size)

    # Even split: every partition has exactly ``base`` columns.
    if remainder == 0:
        start = rank * base
        return start, start + base, True

    # align with the split strategy of torch.tensor_split:
    # the first ``remainder`` partitions receive one extra column each.
    sizes = [base + 1] * remainder + [base] * (world_size - remainder)
    start = sum(sizes[:rank])
    return start, start + sizes[rank], False
class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
    """Cached embedding bag whose embedding dimension is split across ranks.

    Each rank allocates only the columns of the embedding table returned by
    ``get_partition`` for its rank in the default ``torch.distributed``
    process group. ``forward`` performs the lookup on the local column shard
    and then calls ``dual_all_to_all`` to exchange the partial outputs.
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 device=None,
                 cache_ratio=0.01,
                 ids_freq_mapping=None,
                 warmup_ratio=0.7,
                 buffer_size=50_000,
                 pin_weight=False,
                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
        """Set up the local partition, then defer to ``CachedEmbeddingBag``.

        All arguments are forwarded unchanged, in the same order, to the
        parent constructor. ``torch.distributed`` must already be initialized
        because the rank and world size are queried here.
        """
        # The partition attributes are assigned BEFORE the parent constructor
        # runs so that they already exist if the parent invokes
        # ``_weight_alloc`` (which reads ``embedding_dim_per_partition``).
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

        # The third return value (divisibility flag) is not needed here.
        self.partition_start_index, self.partition_end_index, _ = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

        super(ParallelCachedEmbeddingBag,
              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                             sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                             warmup_ratio, buffer_size, pin_weight, evict_strategy)
        self.cache_op = True

    def _weight_alloc(self, dtype, device):
        """Allocate and initialize this rank's column shard of the weight.

        The shard has shape ``(num_embeddings, embedding_dim_per_partition)``
        and is filled uniformly in ``[-1/num_embeddings, 1/num_embeddings]``;
        the ``padding_idx`` row, if any, is zeroed.

        Returns:
            A ``ColoTensor`` wrapping the shard, tagged as sharded along the
            last dimension over ``world_size`` partitions.
        """
        weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
        with torch.no_grad():
            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
            if self.padding_idx is not None:
                weight[self.padding_idx].fill_(0)
        colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
                                          dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                          compute_attr=ComputePattern.TP1D)
        return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

    def forward(
        self,
        indices,
        offsets=None,
        per_sample_weights=None,
        shape_hook=None,
        scatter_dim=0,
        gather_dim=-1,
    ):
        """Look up the local shard and exchange the results across ranks.

        Args:
            indices: ids to look up.
            offsets: forwarded to ``F.embedding_bag``.
            per_sample_weights: forwarded to ``F.embedding_bag``.
            shape_hook: optional callable applied to the local output before
                the communication step.
            scatter_dim: forwarded to ``dual_all_to_all``.
            gather_dim: forwarded to ``dual_all_to_all``.

        Returns:
            The tensor returned by ``dual_all_to_all``.
        """
        if self.cache_op:
            # Id preparation is bookkeeping on the cache manager and must not
            # be tracked by autograd.
            # NOTE(review): presumably remaps ids to rows of the CUDA cache -
            # confirm against CachedParamMgr.prepare_ids.
            with torch.no_grad():
                indices = self.cache_weight_mgr.prepare_ids(indices)

        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)

        if shape_hook is not None:
            output_shard = shape_hook(output_shard)

        output_full = dual_all_to_all(output_shard,
                                      self.weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full

    def set_cache_op(self, cache_op: bool = True):
        """Enable or disable the id-preparation step performed in ``forward``."""
        self.cache_op = cache_op

    @classmethod
    def from_pretrained(
        cls,
        embedding: torch.Tensor,
        freeze: bool = True,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        mode: str = 'mean',
        include_last_offset: bool = False,
        cuda_row_num: int = 100_000,
        ids_freq_mapping: Optional[List[int]] = None,
        warmup_ratio: float = 0.7,
        buffer_size: int = 0,
    ) -> 'ParallelCachedEmbeddingBag':
        """Build an instance from an existing 2-D embedding tensor.

        Args:
            embedding: tensor of shape ``(rows, cols)`` passed to the
                constructor as ``_weight``.
            freeze: when ``True`` the cached CUDA weight does not require
                gradients.
            cuda_row_num: requested number of rows for the CUDA cache; it is
                converted to the ``cache_ratio`` the constructor expects.
            padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse,
            mode, include_last_offset, ids_freq_mapping, warmup_ratio,
            buffer_size: forwarded to the constructor.

        Returns:
            The newly constructed ``ParallelCachedEmbeddingBag``.
        """
        rows, cols = embedding.shape

        # ``__init__`` has no ``cuda_row_num`` parameter (passing it raised a
        # TypeError); it takes ``cache_ratio`` instead. Express the requested
        # row count as a fraction of the table, capped at 1.0.
        # NOTE(review): assumes the parent derives the cached row count as
        # ``num_embeddings * cache_ratio`` - confirm in CachedEmbeddingBag.
        cache_ratio = min(1.0, cuda_row_num / rows) if rows > 0 else 1.0

        embedding_bag = cls(rows,
                            cols,
                            padding_idx,
                            max_norm,
                            norm_type,
                            scale_grad_by_freq,
                            sparse,
                            embedding,
                            mode,
                            include_last_offset,
                            cache_ratio=cache_ratio,
                            ids_freq_mapping=ids_freq_mapping,
                            warmup_ratio=warmup_ratio,
                            buffer_size=buffer_size)
        # ``requires_grad_`` is a method; assigning to it would only shadow
        # the method and leave the gradient flag untouched.
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
        return embedding_bag

    def print_comm_stats_(self):
        """Delegate to the cache manager's ``print_comm_stats``."""
        self.cache_weight_mgr.print_comm_stats()

    def element_size(self):
        """Return ``self.weight.element_size()``."""
        return self.weight.element_size()