mirror of https://github.com/hpcaitech/ColossalAI
142 lines
5.8 KiB
Python
142 lines
5.8 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from typing import List, Optional, Iterator, Tuple
|
|
|
|
from .cached_embedding import CachedEmbeddingBag
|
|
from colossalai.nn._ops._utils import dual_all_to_all
|
|
|
|
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
|
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
|
|
|
|
|
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
    """Compute the slice of the embedding dimension assigned to ``rank``.

    The columns ``[0, embedding_dim)`` are cut into ``world_size`` contiguous
    pieces. When the split is uneven, the first ``embedding_dim % world_size``
    pieces are one column wider than the rest, which follows the split
    strategy of ``torch.tensor_split``.

    Args:
        embedding_dim: total number of embedding columns.
        rank: index of the piece to look up.
        world_size: number of pieces.

    Returns:
        A tuple ``(start, end, divisible)``: ``[start, end)`` is the half-open
        column range of ``rank`` and ``divisible`` is ``True`` when every
        piece has the same width.

    Raises:
        AssertionError: if ``world_size`` is not 1 and ``embedding_dim`` is
            smaller than ``world_size``.
    """
    # A single piece covers every column; no further checks are needed.
    if world_size == 1:
        return 0, embedding_dim, True

    assert embedding_dim >= world_size, \
        f"Embedding dimension {embedding_dim} must be larger than the world size " \
        f"{world_size} of the process group"

    base_width, remainder = divmod(embedding_dim, world_size)

    # Even split: every piece is exactly ``base_width`` columns wide.
    if not remainder:
        begin = rank * base_width
        return begin, begin + base_width, True

    # Uneven split: the leading ``remainder`` pieces get one extra column.
    widths = [base_width + 1] * remainder + [base_width] * (world_size - remainder)
    begin = sum(widths[:rank])
    return begin, begin + widths[rank], False
|
|
|
|
|
|
class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
    """Cached embedding bag whose weight is split along the embedding dimension.

    Each rank allocates only the columns
    ``[partition_start_index, partition_end_index)`` of the embedding table,
    as computed by ``get_partition``. ``forward`` performs the bag lookup on
    the local column shard and passes the result through ``dual_all_to_all``
    over the weight's process group.
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 _weight=None,
                 mode='mean',
                 include_last_offset=False,
                 dtype=None,
                 device=None,
                 cache_ratio=0.01,
                 ids_freq_mapping=None,
                 warmup_ratio=0.7,
                 buffer_size=50_000,
                 pin_weight=False,
                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
        """Record this rank's column partition and build the cached bag.

        All arguments are forwarded unchanged, in order, to the
        ``CachedEmbeddingBag`` constructor. Requires ``torch.distributed`` to
        be initialised, because the rank and world size are queried from it.
        """
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

        # The partition bounds are set before the parent constructor runs;
        # presumably the parent calls `_weight_alloc`, which reads
        # `embedding_dim_per_partition` -- verify against CachedEmbeddingBag.
        self.partition_start_index, self.partition_end_index, _ = get_partition(
            embedding_dim, self.rank, self.world_size)
        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index

        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                         sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                         warmup_ratio, buffer_size, pin_weight, evict_strategy)
        self.cache_op = True

    def _weight_alloc(self, dtype, device):
        """Allocate and initialise this rank's column shard of the weight.

        Returns:
            A ``ColoTensor`` of shape
            ``(num_embeddings, embedding_dim_per_partition)`` filled uniformly
            in ``[-1 / num_embeddings, 1 / num_embeddings]`` with the
            ``padding_idx`` row (if any) zeroed, tagged with a spec that
            shards the last dimension across ``world_size`` partitions.
        """
        weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
        with torch.no_grad():
            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
            if self.padding_idx is not None:
                weight[self.padding_idx].fill_(0)
        colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
                                          dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
                                          compute_attr=ComputePattern.TP1D)
        return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

    def forward(
        self,
        indices,
        offsets=None,
        per_sample_weights=None,
        shape_hook=None,
        scatter_dim=0,
        gather_dim=-1,
    ):
        """Look up the local shard and exchange the result across ranks.

        Args:
            indices: ids to look up, as accepted by ``F.embedding_bag``.
            offsets: bag offsets forwarded to ``F.embedding_bag``.
            per_sample_weights: optional weights forwarded to
                ``F.embedding_bag``.
            shape_hook: optional callable applied to the local lookup result
                before the exchange.
            scatter_dim: ``scatter_dim`` forwarded to ``dual_all_to_all``.
            gather_dim: ``gather_dim`` forwarded to ``dual_all_to_all``.

        Returns:
            The tensor returned by ``dual_all_to_all``.
        """
        if self.cache_op:
            # The id preparation is pure bookkeeping and must not be traced
            # by autograd.
            with torch.no_grad():
                indices = self.cache_weight_mgr.prepare_ids(indices)
        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                       per_sample_weights, self.include_last_offset, self.padding_idx)
        if shape_hook is not None:
            output_shard = shape_hook(output_shard)
        output_full = dual_all_to_all(output_shard,
                                      self.weight.get_process_group(),
                                      scatter_dim=scatter_dim,
                                      gather_dim=gather_dim)
        return output_full

    def set_cache_op(self, cache_op: bool = True):
        """Enable or disable the ``prepare_ids`` step performed in ``forward``."""
        self.cache_op = cache_op

    @classmethod
    def from_pretrained(
        cls,
        embedding: torch.Tensor,
        freeze: bool = True,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        mode: str = 'mean',
        include_last_offset: bool = False,
        cuda_row_num: int = 100_000,
        ids_freq_mapping: Optional[List[int]] = None,
        warmup_ratio: float = 0.7,
        buffer_size: int = 0,
    ) -> 'ParallelCachedEmbeddingBag':
        """Build an embedding bag from an existing 2-D weight tensor.

        Args:
            embedding: tensor of shape ``(rows, cols)`` passed to the
                constructor as ``_weight``.
            freeze: when ``True`` the cached CUDA weight does not require
                gradients.
            cuda_row_num: desired number of rows in the CUDA cache; converted
                to the ``cache_ratio`` that the constructor accepts.
            padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, mode,
                include_last_offset, ids_freq_mapping, warmup_ratio,
                buffer_size: forwarded unchanged to the constructor.

        Returns:
            The newly constructed embedding bag.
        """
        rows, cols = embedding.shape
        # The constructor has no `cuda_row_num` parameter (passing it raised
        # TypeError), so express the requested row budget as a ratio of the
        # table size, clamped to 1.0.
        # NOTE(review): assumes the base class sizes its cache as roughly
        # `num_embeddings * cache_ratio` -- confirm against CachedEmbeddingBag.
        cache_ratio = min(cuda_row_num / rows, 1.0) if rows > 0 else 1.0
        embedding_bag = cls(rows,
                            cols,
                            padding_idx,
                            max_norm,
                            norm_type,
                            scale_grad_by_freq,
                            sparse,
                            embedding,
                            mode,
                            include_last_offset,
                            cache_ratio=cache_ratio,
                            ids_freq_mapping=ids_freq_mapping,
                            warmup_ratio=warmup_ratio,
                            buffer_size=buffer_size)
        # `requires_grad_` is a method: it has to be called. Assigning to it
        # only shadowed the method and left the gradient flag untouched.
        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
        return embedding_bag

    def print_comm_stats_(self):
        """Delegate to the cache manager's ``print_comm_stats``."""
        self.cache_weight_mgr.print_comm_stats()

    def element_size(self):
        """Return ``self.weight.element_size()``."""
        return self.weight.element_size()
|