mirror of https://github.com/hpcaitech/ColossalAI
[embeddings] cache option (#1635)
parent a088022efc
commit e57df80325
@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                          evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(input)
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
+        if cache_op:
+            with torch.no_grad():
+                input = self.cache_weight_mgr.prepare_ids(input)

-        embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                     self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+        embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             embeddings = shape_hook(embeddings)
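The new cache_op flag makes the id-preparation step optional: with the default cache_op=True, forward() still calls cache_weight_mgr.prepare_ids() on the raw ids under torch.no_grad(), while cache_op=False skips that step so ids that were already mapped into the CUDA cache can be passed straight to F.embedding_bag. A minimal usage sketch, assuming bag is an already constructed FreqAwareEmbeddingBag used like a regular nn.Module (construction is not shown in this diff) and ids/offsets are the usual LongTensors accepted by nn.EmbeddingBag:

    import torch

    # default path: forward() prepares the cache internally
    out = bag(ids, offsets)  # cache_op=True

    # prepare the cache once by hand, then reuse the already-mapped ids
    with torch.no_grad():
        cached_ids = bag.cache_weight_mgr.prepare_ids(ids)
    out = bag(cached_ids, offsets, cache_op=False)  # skip prepare_ids inside forward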
@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                           compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

-    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+    def forward(self,
+                indices,
+                offsets=None,
+                per_sample_weights=None,
+                shape_hook=None,
+                scatter_dim=0,
+                gather_dim=-1,
+                cache_op: bool = True):
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(indices)
+        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             output_shard = shape_hook(output_shard)
@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 offsets: torch.Tensor = None,
                 per_sample_weights=None,
                 shape_hook=None,
-                already_split_along_rank=True):
+                already_split_along_rank=True,
+                cache_op=True):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)
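In the tablewise variant the forward still ends with torch.cat(local_output.split(batch_size), 1): the bag outputs for all assigned tables are stacked along the batch dimension, and this call regroups them into one row per sample with the per-table embeddings concatenated along the feature dimension. A toy sketch of that reshaping (shapes and values are illustrative only, not taken from this diff):

    import torch

    batch_size, num_tables, dim = 2, 3, 4
    # rows grouped per table: [t0 s0, t0 s1, t1 s0, t1 s1, t2 s0, t2 s1]
    local_output = torch.arange(num_tables * batch_size * dim, dtype=torch.float32)
    local_output = local_output.reshape(num_tables * batch_size, dim)
    regrouped = torch.cat(local_output.split(batch_size), 1)
    print(regrouped.shape)  # torch.Size([2, 12]) == (batch_size, num_tables * dim)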