mirror of https://github.com/hpcaitech/ColossalAI
[embedding] polish parallel embedding tablewise (#1545)
parent 46c6cc79a9
commit 64169f3e8f

@@ -2,8 +2,12 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
-from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+from .embedding_config import TablewiseEmbeddingBagConfig
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise
+from .parallel_freq_aware_embedding_tablewise_split_cache import ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
 
 __all__ = [
     'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
+    'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
 ]
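The public surface re-exported by the package is unchanged by this split. A minimal import sketch, assuming the package path `colossalai.nn.parallel.layers` taken from the test file further down (the hunk itself does not name the file):

from colossalai.nn.parallel.layers import (CachedParamMgr, EvictionStrategy, FreqAwareEmbeddingBag,
                                           ParallelFreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBagTablewise,
                                           ParallelFreqAwareEmbeddingBagTablewiseSpiltCache,
                                           TablewiseEmbeddingBagConfig)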
@@ -293,7 +293,7 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(zhg) get unique indices"):
+        with record_function("(pre-id) get unique indices"):
             ids = ids.to(self._cache_dev)
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
 
@@ -303,7 +303,7 @@ class CachedParamMgr(torch.nn.Module):
                 f"Please increase cuda_row_num or decrease the training batch size."
         self.evict_backlist = cpu_row_idxs
 
-        with record_function("(zhg) get cpu row idxs"):
+        with record_function("(pre-id) get cpu row idxs"):
             comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
 
         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
@@ -311,18 +311,18 @@ class CachedParamMgr(torch.nn.Module):
         self.num_write_back_history.append(0)
 
         # move sure the cuda rows will not be evicted!
-        with record_function("(zhg) cache update"):
+        with record_function("(pre-id) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)
-
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
 
-        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
+        with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
 
         # update for LFU.
         if self._evict_strategy == EvictionStrategy.LFU:
-            unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
-            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
+            with record_function("(pre-id) lfu cnter updates"):
+                unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
+                self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
 
         return gpu_row_idxs
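The three CachedParamMgr hunks above mostly rename the record_function prefixes from "(zhg)" to "(pre-id)" and add one extra profiling scope for the LFU counter update; these labels only tag ranges in a profiler trace and do not change the cache behaviour. A minimal, self-contained sketch of how such a label shows up under the standard PyTorch profiler; the tensor and the body below are illustrative stand-ins, not the real cache-manager code:

import torch
from torch.profiler import profile, record_function

ids = torch.randint(0, 100, (32,))
with profile() as prof:
    with record_function("(pre-id) get unique indices"):
        # stand-in body: the label, not this computation, is the point
        unique_ids, counts = torch.unique(ids, return_counts=True)

# each record_function label becomes a named row in the profiler summary / trace
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))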
@@ -0,0 +1,27 @@
+import torch
+
+
+class TablewiseEmbeddingBagConfig:
+    '''
+    example:
+    def prepare_tablewise_config(args, cache_ratio, ...):
+        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+        ...
+        return embedding_bag_config_list
+    '''
+
+    def __init__(self,
+                 num_embeddings: int,
+                 cuda_row_num: int,
+                 assigned_rank: int = 0,
+                 buffer_size=50_000,
+                 ids_freq_mapping=None,
+                 initial_weight: torch.tensor = None,
+                 name: str = ""):
+        self.num_embeddings = num_embeddings
+        self.cuda_row_num = cuda_row_num
+        self.assigned_rank = assigned_rank
+        self.buffer_size = buffer_size
+        self.ids_freq_mapping = ids_freq_mapping
+        self.initial_weight = initial_weight
+        self.name = name
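The docstring above only gestures at prepare_tablewise_config. A hedged, self-contained sketch of what such a helper could look like; the helper name comes from the docstring, but the signature, the cache_ratio heuristic and the example table sizes are assumptions for illustration:

from typing import List

from colossalai.nn.parallel.layers import TablewiseEmbeddingBagConfig


def prepare_tablewise_config(table_num_embeddings: List[int],
                             assigned_ranks: List[int],
                             cache_ratio: float = 0.1) -> List[TablewiseEmbeddingBagConfig]:
    # one config per feature table: row count, how many rows to cache on GPU, and the owning rank
    config_list: List[TablewiseEmbeddingBagConfig] = []
    for num_embeddings, rank in zip(table_num_embeddings, assigned_ranks):
        config_list.append(
            TablewiseEmbeddingBagConfig(num_embeddings=num_embeddings,
                                        cuda_row_num=max(1, int(num_embeddings * cache_ratio)),
                                        assigned_rank=rank))
    return config_list


# e.g. three tables split across two ranks, mirroring the layout used in the test further down
embedding_bag_config_list = prepare_tablewise_config([6, 5, 7], [0, 0, 1], cache_ratio=0.5)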
@ -1,42 +1,14 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
|
||||||
from torch.profiler import record_function
|
|
||||||
from typing import List
|
|
||||||
import abc
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .freq_aware_embedding import FreqAwareEmbeddingBag
|
from .freq_aware_embedding import FreqAwareEmbeddingBag
|
||||||
|
|
||||||
from colossalai.tensor import ProcessGroup
|
|
||||||
from .cache_mgr import EvictionStrategy
|
from .cache_mgr import EvictionStrategy
|
||||||
|
from .embedding_config import TablewiseEmbeddingBagConfig
|
||||||
|
from colossalai.tensor import ProcessGroup
|
||||||
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
|
from colossalai.nn._ops._utils import dual_all_to_all_tablewise
|
||||||
|
|
||||||
|
from typing import List
|
||||||
class TablewiseEmbeddingBagConfig:
|
|
||||||
'''
|
|
||||||
example:
|
|
||||||
def prepare_tablewise_config(args, cache_ratio, ...):
|
|
||||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
|
|
||||||
...
|
|
||||||
return embedding_bag_config_list
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
num_embeddings: int,
|
|
||||||
cuda_row_num: int,
|
|
||||||
assigned_rank: int = 0,
|
|
||||||
buffer_size=50_000,
|
|
||||||
ids_freq_mapping=None,
|
|
||||||
initial_weight: torch.tensor = None,
|
|
||||||
name: str = ""):
|
|
||||||
self.num_embeddings = num_embeddings
|
|
||||||
self.cuda_row_num = cuda_row_num
|
|
||||||
self.assigned_rank = assigned_rank
|
|
||||||
self.buffer_size = buffer_size
|
|
||||||
self.ids_freq_mapping = ids_freq_mapping
|
|
||||||
self.initial_weight = initial_weight
|
|
||||||
self.name = name
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
||||||
|
@@ -44,6 +16,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
     all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
     Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
     """
 
     def __init__(self,
                  embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
                  embedding_dim: int,
@@ -98,9 +71,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
             if self.rank_of_tables[table_i] == self.rank:
                 self.idx_offset_list.append(offset_cumsum)
-            else :
+            else:
                 offset_cumsum += table_num_embeddings
 
         # prepare list shape for all_to_all output
         self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
         for rank in self.rank_of_tables:
@@ -112,8 +85,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         local_offsets_list: List(torch.Tensor) = []
         if per_sample_weights != None:
             local_per_sample_weights_list: List(torch.Tensor) = []
 
         offset_pre_end = 0  # local_offsets trick
         for i, handle_table in enumerate(self.assigned_table_list):
             indices_start_position = offsets[batch_size * handle_table]
             if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@@ -122,27 +95,29 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             else:
                 indices_end_position = offsets[batch_size * (handle_table + 1)]
             # 1. local_indices_list:
-            local_indices_list.append(indices.narrow(0, indices_start_position, indices_end_position
-                                                     - indices_start_position).sub(self.idx_offset_list[i]))
+            local_indices_list.append(
+                indices.narrow(0, indices_start_position,
+                               indices_end_position - indices_start_position).sub(self.idx_offset_list[i]))
             # 2. local_offsets_list:
             if i + 1 == len(self.assigned_table_list):
                 # till-the-end special case
                 if not self.include_last_offset:
                     local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size).add(offset_pre_end - offsets[batch_size * (handle_table)])
-                else :
-                    local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                   batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                                                   batch_size).add(offset_pre_end - offsets[batch_size *
+                                                                                            (handle_table)])
+                else:
+                    local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
+                                                   1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 local_offsets_list.append(local_offsets)
             else:
-                local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                               batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
+                local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
+                                               1).add(offset_pre_end - offsets[batch_size * (handle_table)])
                 offset_pre_end = local_offsets[-1]
                 local_offsets_list.append(local_offsets[:-1])
             # 3. local_per_sample_weights_list:
             if per_sample_weights != None:
                 local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
 
         local_indices = torch.cat(local_indices_list, 0)
         local_offsets = torch.cat(local_offsets_list, 0)
         local_per_sample_weights = None
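The narrow/sub calls above slice one flattened EmbeddingBag input into per-table pieces: with batch_size bags per table, table t owns the offsets in [batch_size * t, batch_size * (t + 1)] and the index span those boundaries delimit, shifted back to table-local ids. A small worked example with made-up numbers (two tables, batch size 3, include_last_offset=True); it mirrors the arithmetic only, not the cached embedding itself:

import torch

indices = torch.tensor([1, 2, 3, 0, 5, 6, 7, 9, 6])   # table 0 ids then table 1 ids (global numbering)
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 9])          # batch_size * num_tables + 1 bag boundaries

batch_size, handle_table, table_start_id = 3, 1, 6      # this rank handles table 1, which owns global ids 6..10
start = int(offsets[batch_size * handle_table])         # 5: first index position belonging to table 1
end = int(offsets[batch_size * (handle_table + 1)])     # 9: one past its last index position

local_indices = indices[start:end] - table_start_id     # tensor([0, 1, 3, 0]): remapped to table-local ids
local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * handle_table]
print(local_indices, local_offsets)                     # local_offsets -> tensor([0, 2, 3, 4])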
@@ -150,148 +125,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
 
         local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
-        local_output = torch.cat(local_output.split(batch_size),1)
+        local_output = torch.cat(local_output.split(batch_size), 1)
 
         remains = batch_size % self.world_size
         scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
         output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
         if shape_hook is not None:
             output_full = shape_hook(output_full)
         return output_full
 
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
 
     def element_size(self):
         return self.weight.element_size()
-
-
-class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
-    """
-    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
-    """
-
-    def __init__(self,
-                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
-                 embedding_dim: int,
-                 padding_idx=None,
-                 max_norm=None,
-                 norm_type=2.,
-                 scale_grad_by_freq=False,
-                 sparse=False,
-                 mode='mean',
-                 include_last_offset=False,
-                 dtype=None,
-                 device=None,
-                 warmup_ratio=0.7,
-                 pin_weight=False,
-                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
-        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
-        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
-        self.global_tables_num = len(embedding_bag_config_list)
-        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
-
-        self.assigned_table_list: List[int] = []
-        for i, rank in enumerate(self.rank_of_tables):
-            if rank == self.rank:
-                self.assigned_table_list.append(i)
-        self.include_last_offset = include_last_offset
-        self.pg = ProcessGroup(tp_degree=self.world_size)
-
-        # prepare FreqAwareEmbeddingBag list
-
-        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
-        for config in embedding_bag_config_list:
-            if config.assigned_rank != self.rank:
-                continue
-            self.freq_aware_embedding_bag_list.append(
-                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
-                                      embedding_dim=embedding_dim,
-                                      padding_idx=padding_idx,
-                                      max_norm=max_norm,
-                                      norm_type=norm_type,
-                                      scale_grad_by_freq=scale_grad_by_freq,
-                                      sparse=sparse,
-                                      _weight=config.initial_weight,
-                                      mode=mode,
-                                      include_last_offset=include_last_offset,
-                                      dtype=dtype,
-                                      device=device,
-                                      cuda_row_num=config.cuda_row_num,
-                                      ids_freq_mapping=config.ids_freq_mapping,
-                                      warmup_ratio=warmup_ratio,
-                                      buffer_size=config.buffer_size,
-                                      pin_weight=pin_weight,
-                                      evict_strategy=evict_strategy))
-
-        # prepare list shape for all_to_all output
-        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
-        for rank in self.rank_of_tables:
-            self.embedding_dim_per_rank[rank] += embedding_dim
-
-    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
-        # determine indices to handle
-        batch_size = (offsets.shape[0]) // self.global_tables_num
-        local_output_list = []
-        for i, handle_table in enumerate(self.assigned_table_list):
-            with record_function("(tablewise) prepare indices and offsets"):
-                with record_function("part 1"):
-                    indices_start_position = offsets[batch_size * handle_table]
-                    if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
-                        # till the end special case
-                        indices_end_position = indices.shape[0]
-                    else:
-                        indices_end_position = offsets[batch_size * (handle_table + 1)]
-                with record_function("part 2"):
-                    # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
-                    local_indices = indices.narrow(0, indices_start_position, indices_end_position
-                                                   - indices_start_position).sub(self.global_tables_offsets[handle_table])
-                    if self.include_last_offset:
-                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
-                        local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                       batch_size + 1).sub(offsets[batch_size * (handle_table)])
-                    else:
-                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
-                        local_offsets = offsets.narrow(0, batch_size * handle_table,
-                                                       batch_size).sub(offsets[batch_size * (handle_table)])
-                    local_per_sample_weights = None
-                    if per_sample_weights != None:
-                        local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
-            with record_function("(tablewise) tablewise forward"):
-                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
-                                                                               local_per_sample_weights))
-
-        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
-        local_output = torch.cat(local_output_list, 1)
-        # then concatenate those local_output on the second demension.
-        # use all_to_all
-        remains = batch_size % self.world_size
-        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
-        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
-        if shape_hook is not None:
-            output_full = shape_hook(output_full)
-        return output_full
-
-    def element_size(self):
-        if len(self.assigned_table_list) == 0:
-            return 0
-        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
-
-    def print_comm_stats_(self):
-        cuda_to_cpu_elem_num = 0
-        cpu_to_cuda_elem_num = 0
-        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
-            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
-            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
-        print(
-            f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem"
-        )
-        print(
-            f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem"
-        )
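Both tablewise variants end forward the same way: each rank holds a (batch_size, local embedding dim) block for its own tables, and dual_all_to_all_tablewise, driven by scatter_strides and embedding_dim_per_rank, redistributes it so that each rank presumably ends up with its slice of the batch and the full concatenated feature dimension. A minimal sketch of just the stride arithmetic used above, with illustrative sizes:

world_size, batch_size = 3, 8
remains = batch_size % world_size                                                   # 2 leftover samples
scatter_strides = [batch_size // world_size + int(i < remains) for i in range(world_size)]

assert scatter_strides == [3, 3, 2]        # ranks with i < remains each carry one extra sample
assert sum(scatter_strides) == batch_size  # the per-rank shards tile the whole batch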
@@ -0,0 +1,138 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.profiler import record_function
+
+from .freq_aware_embedding import FreqAwareEmbeddingBag
+
+from colossalai.tensor import ProcessGroup
+from colossalai.nn._ops._utils import dual_all_to_all_tablewise
+from .embedding_config import TablewiseEmbeddingBagConfig
+from .cache_mgr import EvictionStrategy
+
+from typing import List
+import abc
+
+
+class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
+    """
+    every table assigned to this class instance is managed by a FreqAwareEmbeddingBag.
+    """
+
+    def __init__(self,
+                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
+                 embedding_dim: int,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 warmup_ratio=0.7,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
+        super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
+        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
+        self.global_tables_num = len(embedding_bag_config_list)
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
+
+        self.assigned_table_list: List[int] = []
+        for i, rank in enumerate(self.rank_of_tables):
+            if rank == self.rank:
+                self.assigned_table_list.append(i)
+        self.include_last_offset = include_last_offset
+        self.pg = ProcessGroup(tp_degree=self.world_size)
+
+        # prepare FreqAwareEmbeddingBag list
+
+        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        for config in embedding_bag_config_list:
+            if config.assigned_rank != self.rank:
+                continue
+            self.freq_aware_embedding_bag_list.append(
+                FreqAwareEmbeddingBag(num_embeddings=config.num_embeddings,
+                                      embedding_dim=embedding_dim,
+                                      padding_idx=padding_idx,
+                                      max_norm=max_norm,
+                                      norm_type=norm_type,
+                                      scale_grad_by_freq=scale_grad_by_freq,
+                                      sparse=sparse,
+                                      _weight=config.initial_weight,
+                                      mode=mode,
+                                      include_last_offset=include_last_offset,
+                                      dtype=dtype,
+                                      device=device,
+                                      cuda_row_num=config.cuda_row_num,
+                                      ids_freq_mapping=config.ids_freq_mapping,
+                                      warmup_ratio=warmup_ratio,
+                                      buffer_size=config.buffer_size,
+                                      pin_weight=pin_weight,
+                                      evict_strategy=evict_strategy))
+
+        # prepare list shape for all_to_all output
+        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
+        for rank in self.rank_of_tables:
+            self.embedding_dim_per_rank[rank] += embedding_dim
+
+    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
+        # determine indices to handle
+        batch_size = (offsets.shape[0]) // self.global_tables_num
+        local_output_list = []
+        for i, handle_table in enumerate(self.assigned_table_list):
+            with record_function("(tablewise) prepare indices and offsets"):
+                with record_function("part 1"):
+                    indices_start_position = offsets[batch_size * handle_table]
+                    if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
+                        # till the end special case
+                        indices_end_position = indices.shape[0]
+                    else:
+                        indices_end_position = offsets[batch_size * (handle_table + 1)]
+                with record_function("part 2"):
+                    # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
+                    local_indices = indices.narrow(0, indices_start_position, indices_end_position -
+                                                   indices_start_position).sub(self.global_tables_offsets[handle_table])
+                    if self.include_last_offset:
+                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
+                        local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                       batch_size + 1).sub(offsets[batch_size * (handle_table)])
+                    else:
+                        # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
+                        local_offsets = offsets.narrow(0, batch_size * handle_table,
+                                                       batch_size).sub(offsets[batch_size * (handle_table)])
+                    local_per_sample_weights = None
+                    if per_sample_weights != None:
+                        local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
+            with record_function("(tablewise) tablewise forward"):
+                local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
+                                                                               local_per_sample_weights))
+
+        # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim))
+        local_output = torch.cat(local_output_list, 1)
+        # then concatenate those local_output on the second demension.
+        # use all_to_all
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
+
+    def element_size(self):
+        if len(self.assigned_table_list) == 0:
+            return 0
+        return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
+
+    def print_comm_stats_(self):
+        cuda_to_cpu_elem_num = 0
+        cpu_to_cuda_elem_num = 0
+        for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
+            cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
+            cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
+        print(f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem")
+        print(f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem")
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
-    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,19 +209,28 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
 
     # initialize weight
     # 3 feature tables. idx: 0~5, 6~10, 11~17
-    weight_tables = torch.rand(18,5)
+    weight_tables = torch.rand(18, 5)
     weight_table1 = weight_tables[0:6]
     weight_table2 = weight_tables[6:11]
     weight_table3 = weight_tables[11:18]
     embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
-    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
-        num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=6,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=5,
+                                    cuda_row_num=4,
+                                    assigned_rank=0,
+                                    initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(
+        TablewiseEmbeddingBagConfig(num_embeddings=7,
+                                    cuda_row_num=4,
+                                    assigned_rank=1,
+                                    initial_weight=weight_table3.clone().detach().cpu()))
     if rank == 0:
-        _weight = torch.cat([weight_table1, weight_table2],0)
+        _weight = torch.cat([weight_table1, weight_table2], 0)
     else:
         _weight = weight_table3
     model = ParallelFreqAwareEmbeddingBagTablewise(
@@ -249,30 +258,31 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
         fake_grad = rand_grad[0:2]
-    else :
+    else:
         fake_grad = rand_grad[2:]
     res.backward(fake_grad)
     optimizer.step()
     optimizer.zero_grad()
 
     # check correctness
     if rank == 0:
         ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
                                                           include_last_offset=True,
                                                           freeze=False).to(device)
         ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
-        ref_fake_grad = torch.cat(rand_grad.split(5,1),0)
+        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
         ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
                             torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
         ref_res.backward(ref_fake_grad)
         ref_optimizer.step()
         ref_optimizer.zero_grad()
 
         model.cache_weight_mgr.flush()
         recover_weight = model.cache_weight_mgr.weight.to(device)
         ref_weight = ref_model.weight.detach()[:11]
         assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
 
 
 def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
 
@@ -289,11 +299,12 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
-    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
-                                                          include_last_offset=True,
-                                                          freeze=False,
-                                                          cuda_row_num=batch_size * 2,
-                                                          )
+    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
+        coloweight,
+        include_last_offset=True,
+        freeze=False,
+        cuda_row_num=batch_size * 2,
+    )
 
     assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad