mirror of https://github.com/hpcaitech/ColossalAI
[embedding] isolate cache_op from forward (#1645)
Co-authored-by: ric <mkkt_bkkt@mail.ustc.edu.cn>pull/1648/head
parent
c5d39215f6
commit
0767f67a0f
|
@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||||
cuda_row_num = int(num_embeddings * cache_ratio)
|
cuda_row_num = int(num_embeddings * cache_ratio)
|
||||||
# configure weight & cache
|
# configure weight & cache
|
||||||
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
||||||
|
self.cache_op = True
|
||||||
|
|
||||||
def _weight_alloc(self, dtype, device):
|
def _weight_alloc(self, dtype, device):
|
||||||
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
||||||
|
@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||||
evict_strategy=self.evict_strategy)
|
evict_strategy=self.evict_strategy)
|
||||||
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
|
||||||
|
|
||||||
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
|
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
|
||||||
if cache_op:
|
if self.cache_op:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input = self.cache_weight_mgr.prepare_ids(input)
|
input = self.cache_weight_mgr.prepare_ids(input)
|
||||||
|
|
||||||
|
@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||||
yield self.cache_weight_mgr.cuda_cached_weight
|
yield self.cache_weight_mgr.cuda_cached_weight
|
||||||
|
|
||||||
|
def set_cache_op(self, cache_op: bool = True):
|
||||||
|
self.cache_op = cache_op
|
||||||
|
|
||||||
|
|
||||||
############################# Perf Log ###################################
|
############################# Perf Log ###################################
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||||
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
|
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
|
||||||
warmup_ratio, buffer_size, pin_weight, evict_strategy)
|
warmup_ratio, buffer_size, pin_weight, evict_strategy)
|
||||||
|
self.cache_op = True
|
||||||
|
|
||||||
def _weight_alloc(self, dtype, device):
|
def _weight_alloc(self, dtype, device):
|
||||||
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
||||||
|
@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
compute_attr=ComputePattern.TP1D)
|
compute_attr=ComputePattern.TP1D)
|
||||||
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
|
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
indices,
|
self,
|
||||||
offsets=None,
|
indices,
|
||||||
per_sample_weights=None,
|
offsets=None,
|
||||||
shape_hook=None,
|
per_sample_weights=None,
|
||||||
scatter_dim=0,
|
shape_hook=None,
|
||||||
gather_dim=-1,
|
scatter_dim=0,
|
||||||
cache_op: bool = True):
|
gather_dim=-1,
|
||||||
if cache_op:
|
):
|
||||||
|
if self.cache_op:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
indices = self.cache_weight_mgr.prepare_ids(indices)
|
indices = self.cache_weight_mgr.prepare_ids(indices)
|
||||||
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
|
||||||
|
@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
gather_dim=gather_dim)
|
gather_dim=gather_dim)
|
||||||
return output_full
|
return output_full
|
||||||
|
|
||||||
|
def set_cache_op(self, cache_op: bool = True):
|
||||||
|
self.cache_op = cache_op
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
for rank in self.rank_of_tables:
|
for rank in self.rank_of_tables:
|
||||||
self.embedding_dim_per_rank[rank] += embedding_dim
|
self.embedding_dim_per_rank[rank] += embedding_dim
|
||||||
|
|
||||||
def forward(self,
|
self.cache_op = True
|
||||||
indices: torch.Tensor,
|
|
||||||
offsets: torch.Tensor = None,
|
def forward(
|
||||||
per_sample_weights=None,
|
self,
|
||||||
shape_hook=None,
|
indices: torch.Tensor,
|
||||||
already_split_along_rank=True,
|
offsets: torch.Tensor = None,
|
||||||
cache_op=True):
|
per_sample_weights=None,
|
||||||
|
shape_hook=None,
|
||||||
|
already_split_along_rank=True,
|
||||||
|
):
|
||||||
if not already_split_along_rank:
|
if not already_split_along_rank:
|
||||||
# not recommanded. it takes time.
|
# not recommanded. it takes time.
|
||||||
batch_size = (offsets.shape[0]) // self.global_tables_num
|
batch_size = (offsets.shape[0]) // self.global_tables_num
|
||||||
|
@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
# recommanded.
|
# recommanded.
|
||||||
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
|
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
|
||||||
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
|
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
|
||||||
if cache_op:
|
if self.cache_op:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
indices = self.cache_weight_mgr.prepare_ids(local_indices)
|
indices = self.cache_weight_mgr.prepare_ids(local_indices)
|
||||||
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
||||||
|
@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
||||||
return local_indices, local_offsets, local_per_sample_weights
|
return local_indices, local_offsets, local_per_sample_weights
|
||||||
|
|
||||||
|
def set_cache_op(self, cache_op: bool = True):
|
||||||
|
self.cache_op = cache_op
|
||||||
|
|
||||||
def print_comm_stats_(self):
|
def print_comm_stats_(self):
|
||||||
self.cache_weight_mgr.print_comm_stats()
|
self.cache_weight_mgr.print_comm_stats()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue