[embeddings] add already_split_along_rank flag for tablewise mode (#1584)

pull/1588/head
CsRic 2022-09-13 10:50:34 +08:00 committed by GitHub
parent 77399dc91b
commit f3403ff98e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 17 deletions

View File

@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
from colossalai.nn._ops._utils import dual_all_to_all_tablewise from colossalai.nn._ops._utils import dual_all_to_all_tablewise
from typing import List from typing import List
import time
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag): class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
for rank in self.rank_of_tables: for rank in self.rank_of_tables:
self.embedding_dim_per_rank[rank] += embedding_dim self.embedding_dim_per_rank[rank] += embedding_dim
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None): def forward(self,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None,
shape_hook=None,
already_split_along_rank=True):
if not already_split_along_rank:
# not recommanded. it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num batch_size = (offsets.shape[0]) // self.global_tables_num
local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
batch_size, indices, offsets, per_sample_weights)
else:
# recommanded.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = torch.cat(local_output.split(batch_size), 1)
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def split_along_rank(self,
batch_size,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None):
'''
if input indices and offsets haven't been splitted along assigned rank, this function will do it.
it takes time. please consider splitting data during batch loading.
'''
local_indices_list: List(torch.Tensor) = [] local_indices_list: List(torch.Tensor) = []
local_offsets_list: List(torch.Tensor) = [] local_offsets_list: List(torch.Tensor) = []
if per_sample_weights != None: if per_sample_weights != None:
@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights = None local_per_sample_weights = None
if per_sample_weights != None: if per_sample_weights != None:
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
with torch.no_grad(): return local_indices, local_offsets, local_per_sample_weights
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = torch.cat(local_output.split(batch_size), 1)
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def print_comm_stats_(self): def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats() self.cache_weight_mgr.print_comm_stats()

View File

@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
in KJT format in KJT format
''' '''
res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
already_split_along_rank=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
if rank == 0: if rank == 0: