mirror of https://github.com/hpcaitech/ColossalAI
[embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699)
parent: 0e52f3d3d5
commit: 21962e1593
@@ -3,12 +3,12 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
-    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
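Note: for downstream code, this commit is a pure rename of the public classes. A minimal migration sketch (import paths as shown in this diff; no behavioral change implied):

    # before: from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
    # after:
    from colossalai.nn.parallel.layers import CachedEmbeddingBag, ParallelCachedEmbeddingBag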
@@ -1,13 +1,13 @@
 from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
-from .freq_aware_embedding import FreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
 from .embedding_config import TablewiseEmbeddingBagConfig
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
-from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
+from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
-    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
-    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy',
+    'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelCachedEmbeddingBagTablewiseSpiltCache'
 ]
@@ -352,7 +352,8 @@ class CachedParamMgr(torch.nn.Module):
 
         # move sure the cuda rows will not be evicted!
-        with record_function("(cache) prepare_rows_on_cuda"):
-            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+        with self.timer("prepare_rows_on_cuda") as timer:
+            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
 
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
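Note: this hunk also swaps torch.profiler's record_function for a self.timer context manager, which is not defined in the lines shown here. A minimal sketch of a compatible timer (hypothetical helper, purely illustrative):

    import time
    from contextlib import contextmanager

    @contextmanager
    def timer(name: str):
        # Hypothetical stand-in for self.timer: times the wrapped block.
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{name}: {(time.perf_counter() - start) * 1e3:.3f} ms")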
@@ -7,10 +7,10 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter
 
 
-class FreqAwareEmbeddingBag(BaseEmbeddingBag):
-    """FreqAwareEmbeddingBag
+class CachedEmbeddingBag(BaseEmbeddingBag):
+    """CachedEmbeddingBag
 
-    Frequency Aware Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
+    Cached Embedding. Apply a GPU-based software cache approaches to dynamically manage the embedding table in the CPU and GPU memory space.
     It can leverage the id's frequency statistics of the target dataset, by passing a frequency list to param `ids_freq_mapping`.
     You can also apply a navie LFU cache eviction strategy by setting `evict_strategy` as EvictionStrategy.LFU.
 
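Note: the renamed CachedEmbeddingBag keeps the software-cache semantics described in the docstring above. A hedged construction example, mirroring the test further down in this diff (CUDA device assumed; argument values are illustrative):

    model = CachedEmbeddingBag(num_embeddings=10,
                               embedding_dim=8,
                               mode='mean',
                               include_last_offset=True,
                               cache_ratio=0.5,            # fraction of rows resident in the GPU cache
                               ids_freq_mapping=None,      # optional per-id frequency statistics
                               evict_strategy=EvictionStrategy.LFU).to('cuda:0')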
@@ -54,8 +54,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  buffer_size: int = 0,
                  pin_weight: bool = False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
-                                                    scale_grad_by_freq, sparse, mode, include_last_offset)
+        super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
+                                                 scale_grad_by_freq, sparse, mode, include_last_offset)
 
         assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
         self.evict_strategy = evict_strategy
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from colossalai.nn._ops._utils import dual_all_to_all
 
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
@@ -28,7 +28,7 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     return offset, offset + size_list[rank], False
 
 
-class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBag(CachedEmbeddingBag):
 
     def __init__(self,
                  num_embeddings,
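Note: ParallelCachedEmbeddingBag shards the embedding dimension across tensor-parallel ranks via get_partition(embedding_dim, rank, world_size). A simplified sketch of an even column split (illustrative only; the real helper also returns a bool flag, per its Tuple[int, int, bool] signature above):

    def even_partition(embedding_dim: int, rank: int, world_size: int):
        # Distribute columns as evenly as possible; earlier ranks absorb the remainder.
        base, rem = divmod(embedding_dim, world_size)
        sizes = [base + (1 if r < rem else 0) for r in range(world_size)]
        offset = sum(sizes[:rank])
        return offset, offset + sizes[rank]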
@@ -56,7 +56,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                                                        embedding_dim, self.rank, self.world_size)
         self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
 
-        super(ParallelFreqAwareEmbeddingBag,
+        super(ParallelCachedEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
@@ -115,7 +115,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                         ids_freq_mapping: Optional[List[int]] = None,
                         warmup_ratio: float = 0.7,
                         buffer_size: int = 0,
-                        ) -> 'ParallelFreqAwareEmbeddingBag':
+                        ) -> 'ParallelCachedEmbeddingBag':
         rows, cols = embedding.shape
         embedding_bag = cls(rows,
                             cols,
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 from .cache_mgr import EvictionStrategy
 from .embedding_config import TablewiseEmbeddingBagConfig
 from colossalai.tensor import ProcessGroup
@@ -12,9 +12,9 @@ from typing import List
 import time
 
 
-class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
+class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
     """
-    all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
+    all tables assigned to this class instance are managed by a single CachedEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
 
@@ -62,7 +62,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         self.cache_ratio = cache_ratio
         # table-associate cache
         cuda_row_num = int(cache_ratio * self.num_embeddings)
-        super(ParallelFreqAwareEmbeddingBagTablewise,
+        super(ParallelCachedEmbeddingBagTablewise,
               self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
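Note: the tablewise variant derives its GPU row budget from the cache ratio, as the hunk above shows. A worked example of that arithmetic:

    num_embeddings, cache_ratio = 10_000, 0.6
    cuda_row_num = int(cache_ratio * num_embeddings)    # 6000 rows kept resident on the GPU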
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.profiler import record_function
 
-from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .cached_embedding import CachedEmbeddingBag
 
 from colossalai.tensor import ProcessGroup
 from colossalai.nn._ops._utils import dual_all_to_all_tablewise
@@ -14,9 +14,9 @@ from typing import List
 import abc
 
 
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
     """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    every table assigned to this class instance is managed by a CachedEmbeddingBag.
     """
 
     def __init__(self,
@@ -34,7 +34,7 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
                  warmup_ratio=0.7,
                  pin_weight=False,
                  evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -49,31 +49,31 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
         self.include_last_offset = include_last_offset
         self.pg = ProcessGroup(tp_degree=self.world_size)
 
-        # prepare FreqAwareEmbeddingBag list
-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        # prepare CachedEmbeddingBag list
+        self.cached_embedding_bag_list: nn.ModuleList = nn.ModuleList()
         for config in embedding_bag_config_list:
             if config.assigned_rank != self.rank:
                 continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
-                                      embedding_dim=embedding_dim,
-                                      padding_idx=padding_idx,
-                                      max_norm=max_norm,
-                                      norm_type=norm_type,
-                                      scale_grad_by_freq=scale_grad_by_freq,
-                                      sparse=sparse,
-                                      _weight=config.initial_weight,
-                                      mode=mode,
-                                      include_last_offset=include_last_offset,
-                                      dtype=dtype,
-                                      device=device,
-                                      cuda_row_num=config.cuda_row_num,
-                                      ids_freq_mapping=config.ids_freq_mapping,
-                                      warmup_ratio=warmup_ratio,
-                                      buffer_size=config.buffer_size,
-                                      pin_weight=pin_weight,
-                                      evict_strategy=evict_strategy))
+            self.cached_embedding_bag_list.append(
+                CachedEmbeddingBag(num_embeddings=config.num_embeddings,
+                                   embedding_dim=embedding_dim,
+                                   padding_idx=padding_idx,
+                                   max_norm=max_norm,
+                                   norm_type=norm_type,
+                                   scale_grad_by_freq=scale_grad_by_freq,
+                                   sparse=sparse,
+                                   _weight=config.initial_weight,
+                                   mode=mode,
+                                   include_last_offset=include_last_offset,
+                                   dtype=dtype,
+                                   device=device,
+                                   cuda_row_num=config.cuda_row_num,
+                                   ids_freq_mapping=config.ids_freq_mapping,
+                                   warmup_ratio=warmup_ratio,
+                                   buffer_size=config.buffer_size,
+                                   pin_weight=pin_weight,
+                                   evict_strategy=evict_strategy))
 
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
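Note: in the SpiltCache variant each assigned table gets its own CachedEmbeddingBag, built from a TablewiseEmbeddingBagConfig. A hedged usage sketch (only config fields visible in this diff are used; the keyword order of the config constructor is an assumption):

    configs = [
        TablewiseEmbeddingBagConfig(num_embeddings=1000, cuda_row_num=128, assigned_rank=0),
        TablewiseEmbeddingBagConfig(num_embeddings=2000, cuda_row_num=256, assigned_rank=1),
    ]
    model = ParallelCachedEmbeddingBagTablewiseSpiltCache(configs, embedding_dim=8)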
@@ -109,8 +109,8 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
             if per_sample_weights != None:
                 local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
             with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
+                local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets,
+                                                                           local_per_sample_weights))
 
         # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
         local_output = torch.cat(local_output_list, 1)
@@ -126,13 +126,13 @@ class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
     def element_size(self):
         if len(self.assigned_table_list) == 0:
             return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+        return self.cached_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
 
     def print_comm_stats_(self):
         cuda_to_cpu_elem_num = 0
         cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        for cached_embedding_bag in self.cached_embedding_bag_list:
+            cuda_to_cpu_elem_num += cached_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += cached_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
         print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
         print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
@@ -12,8 +12,8 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
+    ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
 def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
     evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
-    model = FreqAwareEmbeddingBag(NUM_EMBED,
-                                  EMBED_DIM,
-                                  mode='mean',
-                                  include_last_offset=True,
-                                  cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
-                                  ids_freq_mapping=None,
-                                  evict_strategy=evict_strategy).to(device)
+    model = CachedEmbeddingBag(NUM_EMBED,
+                               EMBED_DIM,
+                               mode='mean',
+                               include_last_offset=True,
+                               cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
+                               ids_freq_mapping=None,
+                               evict_strategy=evict_strategy).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
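Note: the test validates the cached bag against a plain torch.nn.EmbeddingBag built from the same weights. A condensed sketch of that comparison (indices/offsets values are illustrative):

    indices = torch.tensor([1, 2, 3, 1, 4], device=device)
    offsets = torch.tensor([0, 2, 5], device=device)    # include_last_offset=True style
    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
                                                      mode='mean',
                                                      include_last_offset=True)
    torch.testing.assert_close(model(indices, offsets), ref_model(indices, offsets))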
@@ -151,14 +151,14 @@
 @pytest.mark.parametrize('init_freq', [True, False])
 def test_lfu_strategy(init_freq: bool):
     # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(5,
-                                5,
-                                cache_ratio=3 / 5,
-                                buffer_size=0,
-                                pin_weight=True,
-                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
-                                warmup_ratio=1.0,
-                                evict_strategy=EvictionStrategy.LFU)
+    Bag = CachedEmbeddingBag(5,
+                             5,
+                             cache_ratio=3 / 5,
+                             buffer_size=0,
+                             pin_weight=True,
+                             ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                             warmup_ratio=1.0,
+                             evict_strategy=EvictionStrategy.LFU)
 
     # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
     offsets = torch.tensor([0], device="cuda:0")
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
         _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
-    model = ParallelFreqAwareEmbeddingBagTablewise(
+    model = ParallelCachedEmbeddingBagTablewise(
         embedding_bag_config_list,
         embedding_dim=5,
         _weight=_weight,
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+    model = ParallelCachedEmbeddingBag.from_pretrained(
         coloweight,
         include_last_offset=True,
         freeze=False,