mirror of https://github.com/hpcaitech/ColossalAI

[embeddings] cache option (#1635)

parent a088022efc
commit e57df80325
@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                      evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(input)
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
+        if cache_op:
+            with torch.no_grad():
+                input = self.cache_weight_mgr.prepare_ids(input)

-        embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                     self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+        embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             embeddings = shape_hook(embeddings)
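Beyond the hunk itself, the new flag changes the calling convention: with the default cache_op=True, forward() still translates raw ids into cached row ids via cache_weight_mgr.prepare_ids(); with cache_op=False it assumes the caller has already done that. A minimal usage sketch (not part of the commit; the `bag` argument stands for an already constructed FreqAwareEmbeddingBag, whose constructor arguments are omitted because the diff does not show them):

import torch

def lookup(bag, raw_ids, offsets):
    # Default path: the cache op runs inside forward() (cache_op=True).
    out_default = bag(raw_ids, offsets)

    # Decoupled path: translate ids to cache rows once, then skip it in forward().
    with torch.no_grad():
        cached_ids = bag.cache_weight_mgr.prepare_ids(raw_ids)
    out_decoupled = bag(cached_ids, offsets, cache_op=False)
    return out_default, out_decoupled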
@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                           compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

-    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+    def forward(self,
+                indices,
+                offsets=None,
+                per_sample_weights=None,
+                shape_hook=None,
+                scatter_dim=0,
+                gather_dim=-1,
+                cache_op: bool = True):
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(indices)
+        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             output_shard = shape_hook(output_shard)
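One thing the reworked ParallelFreqAwareEmbeddingBag.forward() signature enables is reusing the prepared ids across repeated lookups on the same batch, e.g. under gradient accumulation, since cache_op=False leaves the cache untouched. The following is a hedged follow-up sketch, not code from this commit; `bag`, `indices`, `offsets` and `accum_steps` are assumed to be supplied by the caller:

import torch

def accumulate(bag, indices, offsets, accum_steps):
    # Run the cache op once: afterwards `slots` index cuda_cached_weight directly.
    with torch.no_grad():
        slots = bag.cache_weight_mgr.prepare_ids(indices)
    for _ in range(accum_steps):
        shard = bag(slots, offsets, cache_op=False)   # lookup only, cache untouched
        shard.sum().backward()                        # placeholder loss for the sketch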
@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 offsets: torch.Tensor = None,
                 per_sample_weights=None,
                 shape_hook=None,
-                already_split_along_rank=True):
+                already_split_along_rank=True,
+                cache_op=True):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)
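The tablewise variant follows the same convention, with the extra caveat that when cache_op=False the indices passed in are used as-is, so they should already be split along ranks (already_split_along_rank=True) and already translated to cache rows. A hedged sketch of that calling pattern (names other than forward()'s own parameters are placeholders; `bag` is a ParallelFreqAwareEmbeddingBagTablewise instance whose construction is not shown here):

import torch

def tablewise_lookup(bag, local_indices, local_offsets):
    # Translate the rank-local ids to cache rows outside of forward().
    with torch.no_grad():
        local_slots = bag.cache_weight_mgr.prepare_ids(local_indices)
    # Skip both the (slow) re-splitting path and the in-forward cache op.
    return bag(local_slots, local_offsets,
               already_split_along_rank=True,
               cache_op=False)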