mirror of https://github.com/hpcaitech/ColossalAI
[embeddings] add already_split_along_rank flag for tablewise mode (#1584)
parent 77399dc91b
commit f3403ff98e
@@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
 from typing import List
+import time
 
 
 class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
 
-    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
-        batch_size = (offsets.shape[0]) // self.global_tables_num
+    def forward(self,
+                indices: torch.Tensor,
+                offsets: torch.Tensor = None,
+                per_sample_weights=None,
+                shape_hook=None,
+                already_split_along_rank=True):
+        if not already_split_along_rank:
+            # not recommended: splitting here takes time.
+            batch_size = (offsets.shape[0]) // self.global_tables_num
+            local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
+                batch_size, indices, offsets, per_sample_weights)
+        else:
+            # recommended.
+            batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
+            local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
+        local_output = torch.cat(local_output.split(batch_size), 1)
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
+
+    def split_along_rank(self,
+                         batch_size,
+                         indices: torch.Tensor,
+                         offsets: torch.Tensor = None,
+                         per_sample_weights=None):
+        '''
+        If the input indices and offsets have not been split along the assigned rank, this function does it.
+        This takes time; please consider splitting the data during batch loading instead.
+        '''
         local_indices_list: List[torch.Tensor] = []
         local_offsets_list: List[torch.Tensor] = []
         if per_sample_weights is not None:
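Note on the stride arithmetic in the new forward: the scatter strides split an uneven batch so that the first batch_size % world_size ranks each carry one extra sample before dual_all_to_all_tablewise exchanges the per-table slices. A minimal standalone sketch of that arithmetic (plain Python; the helper name scatter_strides_for is illustrative and not part of ColossalAI):

def scatter_strides_for(batch_size: int, world_size: int) -> list:
    # Reproduces the slice-size arithmetic used in forward(); illustration only,
    # it performs no communication.
    remains = batch_size % world_size
    # ranks with index < remains take one extra sample each
    return [batch_size // world_size + int(i < remains) for i in range(world_size)]

assert scatter_strides_for(10, 4) == [3, 3, 2, 2]   # 10 samples over 4 ranks
assert sum(scatter_strides_for(10, 4)) == 10        # every sample lands on exactly one rank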
@@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights = None
         if per_sample_weights is not None:
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
-                                       local_per_sample_weights, self.include_last_offset, self.padding_idx)
-        local_output = torch.cat(local_output.split(batch_size), 1)
-
-        remains = batch_size % self.world_size
-        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
-        if shape_hook is not None:
-            output_full = shape_hook(output_full)
-        return output_full
+        return local_indices, local_offsets, local_per_sample_weights
 
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
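For reference on the input layout: the test below feeds a KJT-format batch in which, with B samples and T tables laid out table-major, offsets holds T * B + 1 entries (the final offset is included) and the bags of table t occupy offset slots t*B through (t+1)*B. The sketch below is an illustration of that layout only, not the ColossalAI split_along_rank implementation; it shows how one table's bags can be sliced out of such a pair, which is the kind of per-table work the recommended path moves into batch loading.

import torch

def slice_table_bags(indices, offsets, table_id, batch_size):
    # Illustration only: extract one table's bags from a table-major KJT layout
    # whose offsets tensor includes the final offset.
    start, end = table_id * batch_size, (table_id + 1) * batch_size
    lo, hi = offsets[start].item(), offsets[end].item()
    local_indices = indices[lo:hi]
    local_offsets = offsets[start:end + 1] - lo    # rebase the slice to start at 0
    return local_indices, local_offsets

# Toy batch shaped like the one in the test below (assumed: 3 samples, 3 tables).
indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11])
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13])
print(slice_table_bags(indices, offsets, table_id=1, batch_size=3))
# -> (tensor([6, 7, 9, 6, 8]), tensor([0, 2, 3, 5]))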
@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     in KJT format
     '''
     res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
-                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
+                already_split_along_rank=False)
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
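The test exercises the slower path on purpose: each rank runs the same code and therefore receives the full global KJT batch, so it passes already_split_along_rank=False and lets split_along_rank carve out the bags of its assigned tables inside forward. The recommended production pattern is the default already_split_along_rank=True, where the data loader hands each rank only its own tables' bags and forward skips the re-splitting entirely.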