# ColossalAI/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
# (GitHub file-view metadata: 28 lines, 855 B, Python)

from typing import Optional

import torch
class TablewiseEmbeddingBagConfig:
    """Per-table configuration record for a table-wise embedding bag.

    This is a plain attribute holder: ``__init__`` stores each argument
    unchanged on ``self``. It does no validation and does not copy anything.

    Example::

        def prepare_tablewise_config(args, cache_ratio, ...):
            embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
            ...
            return embedding_bag_config_list

    Args:
        num_embeddings: Number of embeddings (rows) for this table.
        cuda_row_num: Row count for the CUDA side. Presumably the number of
            rows of this table kept on the GPU -- TODO confirm against the
            consumer of this config.
        assigned_rank: Rank this table is assigned to. Defaults to ``0``.
        buffer_size: Defaults to ``50_000``. Its exact meaning is defined by
            the consumer of this config, which is not visible here.
        ids_freq_mapping: Optional id-frequency data. Its type and layout are
            defined by the consumer (not visible here). It is stored by
            reference.
        initial_weight: Optional tensor of initial weights for the table. It
            is stored by reference and not copied.
        name: Optional table name. Defaults to ``""``.
    """

    def __init__(self,
                 num_embeddings: int,
                 cuda_row_num: int,
                 assigned_rank: int = 0,
                 buffer_size: int = 50_000,
                 ids_freq_mapping=None,
                 # ``torch.Tensor`` is the tensor type. ``torch.tensor`` (the
                 # previous annotation) is a factory function, not a type.
                 initial_weight: Optional[torch.Tensor] = None,
                 name: str = ""):
        self.num_embeddings = num_embeddings
        self.cuda_row_num = cuda_row_num
        self.assigned_rank = assigned_rank
        self.buffer_size = buffer_size
        self.ids_freq_mapping = ids_freq_mapping
        self.initial_weight = initial_weight
        self.name = name