mirror of https://github.com/hpcaitech/ColossalAI
[embeddings] add already_split_along_rank flag for tablewise mode (#1584)
parent
77399dc91b
commit
f3403ff98e
|
@ -9,6 +9,7 @@ from colossalai.tensor import ProcessGroup
|
||||||
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
|
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
|
@ -79,8 +80,43 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
for rank in self.rank_of_tables:
|
for rank in self.rank_of_tables:
|
||||||
self.embedding_dim_per_rank[rank] += embedding_dim
|
self.embedding_dim_per_rank[rank] += embedding_dim
|
||||||
|
|
||||||
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
|
def forward(self,
|
||||||
batch_size = (offsets.shape[0]) // self.global_tables_num
|
indices: torch.Tensor,
|
||||||
|
offsets: torch.Tensor = None,
|
||||||
|
per_sample_weights=None,
|
||||||
|
shape_hook=None,
|
||||||
|
already_split_along_rank=True):
|
||||||
|
if not already_split_along_rank:
|
||||||
|
# not recommended. it takes time.
|
||||||
|
batch_size = (offsets.shape[0]) // self.global_tables_num
|
||||||
|
local_indices, local_offsets, local_per_sample_weights = self.split_along_rank(
|
||||||
|
batch_size, indices, offsets, per_sample_weights)
|
||||||
|
else:
|
||||||
|
# recommended.
|
||||||
|
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
|
||||||
|
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
|
||||||
|
with torch.no_grad():
|
||||||
|
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
|
||||||
|
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
||||||
|
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
||||||
|
local_per_sample_weights, self.include_last_offset, self.padding_idx)
|
||||||
|
local_output = torch.cat(local_output.split(batch_size), 1)
|
||||||
|
remains = batch_size % self.world_size
|
||||||
|
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
||||||
|
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
||||||
|
if shape_hook is not None:
|
||||||
|
output_full = shape_hook(output_full)
|
||||||
|
return output_full
|
||||||
|
|
||||||
|
def split_along_rank(self,
|
||||||
|
batch_size,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
offsets: torch.Tensor = None,
|
||||||
|
per_sample_weights=None):
|
||||||
|
'''
|
||||||
|
if the input indices and offsets haven't been split along the assigned rank, this function will do it.
|
||||||
|
it takes time. please consider splitting data during batch loading.
|
||||||
|
'''
|
||||||
local_indices_list: List(torch.Tensor) = []
|
local_indices_list: List(torch.Tensor) = []
|
||||||
local_offsets_list: List(torch.Tensor) = []
|
local_offsets_list: List(torch.Tensor) = []
|
||||||
if per_sample_weights != None:
|
if per_sample_weights != None:
|
||||||
|
@ -145,20 +181,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
local_per_sample_weights = None
|
local_per_sample_weights = None
|
||||||
if per_sample_weights != None:
|
if per_sample_weights != None:
|
||||||
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
||||||
with torch.no_grad():
|
return local_indices, local_offsets, local_per_sample_weights
|
||||||
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
|
|
||||||
|
|
||||||
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
|
||||||
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
|
|
||||||
local_per_sample_weights, self.include_last_offset, self.padding_idx)
|
|
||||||
local_output = torch.cat(local_output.split(batch_size), 1)
|
|
||||||
|
|
||||||
remains = batch_size % self.world_size
|
|
||||||
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
|
|
||||||
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
|
|
||||||
if shape_hook is not None:
|
|
||||||
output_full = shape_hook(output_full)
|
|
||||||
return output_full
|
|
||||||
|
|
||||||
def print_comm_stats_(self):
|
def print_comm_stats_(self):
|
||||||
self.cache_weight_mgr.print_comm_stats()
|
self.cache_weight_mgr.print_comm_stats()
|
||||||
|
|
|
@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||||
in KJT format
|
in KJT format
|
||||||
'''
|
'''
|
||||||
res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
|
res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
|
||||||
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
|
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
|
||||||
|
already_split_along_rank=False)
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
|
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
|
||||||
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
|
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
|
Loading…
Reference in New Issue