mirror of https://github.com/hpcaitech/ColossalAI
[embedding] isolate cache_op from forward (#1645)
Co-authored-by: ric <mkkt_bkkt@mail.ustc.edu.cn>
branch: pull/1648/head
parent c5d39215f6
commit 0767f67a0f
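In short, this change removes the per-call `cache_op` keyword from `forward()` and stores the flag as module state, settable via a new `set_cache_op()` method. A before/after sketch of a call site (the `bag`, `ids`, and `offsets` names are illustrative, not from the diff):

# before this commit: the flag rode along on every forward call
out = bag(ids, offsets, cache_op=False)

# after this commit: set the state once; the forward signature stays clean
bag.set_cache_op(False)
out = bag(ids, offsets)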
@@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         cuda_row_num = int(num_embeddings * cache_ratio)
         # configure weight & cache
         self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
@@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                                  evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
-        if cache_op:
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
+        if self.cache_op:
             with torch.no_grad():
                 input = self.cache_weight_mgr.prepare_ids(input)

@@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         yield self.cache_weight_mgr.cuda_cached_weight

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     ############################# Perf Log ###################################
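The hunks above gate cache maintenance on `self.cache_op` instead of a forward parameter. A minimal runnable sketch of that pattern, assuming nothing beyond stock PyTorch; `ToyCachedEmbeddingBag` and `_prepare_ids` are hypothetical stand-ins for `FreqAwareEmbeddingBag` and `cache_weight_mgr.prepare_ids`, and only `cache_op` / `set_cache_op` mirror the diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyCachedEmbeddingBag(nn.Module):
    """Hypothetical stand-in showing the cache_op-as-state pattern."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.cache_op = True    # default matches the diff: cache ops enabled

    def set_cache_op(self, cache_op: bool = True):
        self.cache_op = cache_op

    def _prepare_ids(self, input: torch.Tensor) -> torch.Tensor:
        # stand-in for the cache manager's id remapping; identity here
        return input

    def forward(self, input, offsets=None):
        if self.cache_op:
            # id remapping is cache bookkeeping, kept out of autograd
            with torch.no_grad():
                input = self._prepare_ids(input)
        return F.embedding_bag(input, self.weight, offsets)

bag = ToyCachedEmbeddingBag(10, 4)
ids = torch.tensor([1, 2, 3, 4])
offsets = torch.tensor([0, 2])
out = bag(ids, offsets)     # cache maintenance runs inside forward
bag.set_cache_op(False)     # e.g. ids were already remapped upstream
out = bag(ids, offsets)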
@@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
@@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                           compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

-    def forward(self,
-                indices,
-                offsets=None,
-                per_sample_weights=None,
-                shape_hook=None,
-                scatter_dim=0,
-                gather_dim=-1,
-                cache_op: bool = True):
-        if cache_op:
+    def forward(
+        self,
+        indices,
+        offsets=None,
+        per_sample_weights=None,
+        shape_hook=None,
+        scatter_dim=0,
+        gather_dim=-1,
+    ):
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(indices)
         output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
@@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                        gather_dim=gather_dim)
         return output_full

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     @classmethod
     def from_pretrained(
             cls,
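With the flag on the module, a trainer can flip cache management for every embedding bag in a model at once instead of threading a keyword through each call. A hedged helper sketch; the traversal and the function name are ours, only `set_cache_op` comes from the diff:

import torch.nn as nn

def set_cache_op_recursively(model: nn.Module, cache_op: bool) -> None:
    # walk the module tree and update any submodule exposing the new setter
    for module in model.modules():
        if hasattr(module, "set_cache_op"):
            module.set_cache_op(cache_op)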
@@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
+        self.cache_op = True

-    def forward(self,
-                indices: torch.Tensor,
-                offsets: torch.Tensor = None,
-                per_sample_weights=None,
-                shape_hook=None,
-                already_split_along_rank=True,
-                cache_op=True):
+    def forward(
+        self,
+        indices: torch.Tensor,
+        offsets: torch.Tensor = None,
+        per_sample_weights=None,
+        shape_hook=None,
+        already_split_along_rank=True,
+    ):
         if not already_split_along_rank:
             # not recommended. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommended.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        if cache_op:
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(local_indices)
         local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
@@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
         return local_indices, local_offsets, local_per_sample_weights

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
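One context line in the tablewise hunk deserves a worked example: `batch_size = (offsets.shape[0]) // len(self.assigned_table_list)` recovers the batch size from the flattened offsets tensor, since each locally assigned table contributes one offsets entry per sample. Illustrative numbers only:

import torch

assigned_table_count = 3    # plays the role of len(self.assigned_table_list)
batch = 4
offsets = torch.arange(assigned_table_count * batch) * 2    # 12 offsets in total
batch_size = offsets.shape[0] // assigned_table_count
assert batch_size == batch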