[embedding] freq_aware_embedding: add small functions for caller application (#1537)

pull/1541/head
CsRic 2022-09-05 15:12:53 +08:00 committed by GitHub
parent 70129603aa
commit 964123ae0f
5 changed files with 214 additions and 46 deletions
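A minimal usage sketch of the caller-facing helpers added here, element_size() and print_comm_stats_() (not part of this commit). It assumes torch.distributed is already initialized with two ranks and that the TablewiseEmbeddingBagConfig fields not shown (ids_freq_mapping, initial_weight) default to None:

import torch
from colossalai.nn.parallel.layers import EvictionStrategy, \
    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig

# two hypothetical tables, one assigned to each rank
config_list = [
    TablewiseEmbeddingBagConfig(num_embeddings=6, cuda_row_num=4, assigned_rank=0),
    TablewiseEmbeddingBagConfig(num_embeddings=7, cuda_row_num=4, assigned_rank=1),
]
model = ParallelFreqAwareEmbeddingBagTablewise(config_list,
                                               embedding_dim=8,
                                               cuda_row_num=8,
                                               include_last_offset=True,
                                               evict_strategy=EvictionStrategy.LFU)

# element_size() returns the byte width of one weight element,
# handy for sizing the CUDA row cache on the caller side.
bytes_per_row = model.element_size() * 8    # 8 = embedding_dim

# print_comm_stats_() reports the CPU<->CUDA traffic accumulated by the cache manager.
model.print_comm_stats_()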

View File

@@ -4,10 +4,11 @@ from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig',
'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
]

View File

@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag
from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
__all__ = [
'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', 'ParallelFreqAwareEmbeddingBagTablewiseSpiltCache'
]

View File

@@ -121,3 +121,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
buffer_size=buffer_size)
embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
return embedding_bag
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()
def element_size(self):
return self.weight.element_size()

View File

@@ -1,9 +1,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.profiler import record_function
from typing import List
import abc
import torch.nn.functional as F
from .freq_aware_embedding import FreqAwareEmbeddingBag
from colossalai.tensor import ProcessGroup
@@ -38,7 +39,137 @@ class TablewiseEmbeddingBagConfig:
self.name = name
class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
"""
all tables assigned to this class instance are managed by a single FreqAwareEmbeddingBag.
Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight.
"""
def __init__(self,
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
embedding_dim: int,
padding_idx=None,
max_norm=None,
norm_type=2.,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
mode='mean',
include_last_offset=False,
dtype=None,
device=None,
cuda_row_num=0,
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
self.global_tables_num = len(embedding_bag_config_list)
self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0).cuda()
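# global_tables_offsets[i] is the first global row id of table i (the last entry is the total row count)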
self.assigned_table_list: List[int] = []
self.pg = ProcessGroup(tp_degree=self.world_size)
self.num_embeddings = 0
for i, rank in enumerate(self.rank_of_tables):
if rank == self.rank:
self.assigned_table_list.append(i)
self.num_embeddings += self.global_table_num_embeddings_list[i]
self.include_last_offset = include_last_offset
ids_freq_mapping = []
for config in embedding_bag_config_list:
if config.assigned_rank == self.rank:
if config.ids_freq_mapping is not None:
ids_freq_mapping.extend(config.ids_freq_mapping)
else:
ids_freq_mapping = None
break
# table-associate cache
super(ParallelFreqAwareEmbeddingBagTablewise,
self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
# for assigned tables reconnection:
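# idx_offset_list[i] counts the rows of tables owned by other ranks that precede assigned table i,
# so subtracting it maps a global row id to this rank's local row id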
self.idx_offset_list = []
offset_cumsum = 0
for table_i, table_num_embeddings in enumerate(self.global_table_num_embeddings_list):
if self.rank_of_tables[table_i] == self.rank:
self.idx_offset_list.append(offset_cumsum)
else:
offset_cumsum += table_num_embeddings
# prepare list shape for all_to_all output
self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
for rank in self.rank_of_tables:
self.embedding_dim_per_rank[rank] += embedding_dim
def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
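# offsets concatenates batch_size bag offsets for every global table
# (plus one trailing offset when include_last_offset is True)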
batch_size = (offsets.shape[0]) // self.global_tables_num
local_indices_list: List[torch.Tensor] = []
local_offsets_list: List[torch.Tensor] = []
if per_sample_weights is not None:
local_per_sample_weights_list: List[torch.Tensor] = []
offset_pre_end = 0 # local_offsets trick
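# offset_pre_end tracks where the previous assigned table's local offsets ended,
# so each table's offsets are rebased to continue the concatenated local index stream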
for i, handle_table in enumerate(self.assigned_table_list):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
# till-the-end special case
indices_end_position = indices.shape[0]
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
# 1. local_indices_list:
local_indices_list.append(indices.narrow(0, indices_start_position, indices_end_position
- indices_start_position).sub(self.idx_offset_list[i]))
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).add(offset_pre_end - offsets[batch_size * (handle_table)])
else:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets_list.append(local_offsets)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size + 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
offset_pre_end = local_offsets[-1]
local_offsets_list.append(local_offsets[:-1])
# 3. local_per_sample_weights_list:
if per_sample_weights is not None:
local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
local_indices = torch.cat(local_indices_list, 0)
local_offsets = torch.cat(local_offsets_list, 0)
local_per_sample_weights = None
if per_sample_weights is not None:
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
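# local_output is (num_assigned_tables * batch_size, embedding_dim); regroup it to
# (batch_size, num_assigned_tables * embedding_dim) before the all_to_all exchange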
local_output = torch.cat(local_output.split(batch_size),1)
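# scatter the batch dimension across ranks as evenly as possible: the first `remains` ranks get one extra sample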
remains = batch_size % self.world_size
scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
output_full = dual_all_to_all_tablewise(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()
def element_size(self):
return self.weight.element_size()
class ParallelFreqAwareEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
"""
Every table assigned to this class instance is managed by its own FreqAwareEmbeddingBag, i.e. one cache per table.
"""
@@ -58,7 +189,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
warmup_ratio=0.7,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
super(ParallelFreqAwareEmbeddingBagTablewiseSpiltCache, self).__init__()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list]
@@ -109,24 +240,30 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
batch_size = (offsets.shape[0]) // self.global_tables_num
local_output_list = []
for i, handle_table in enumerate(self.assigned_table_list):
with record_function("(tablewise) prepare indices and offsets"):
with record_function("part 1"):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
# till the end special case
indices_end_position = indices.shape[0]
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
local_indices = indices[indices_start_position:indices_end_position] - \
self.global_tables_offsets[handle_table]
with record_function("part 2"):
# local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table]
local_indices = indices.narrow(0, indices_start_position, indices_end_position
- indices_start_position).sub(self.global_tables_offsets[handle_table])
if self.include_last_offset:
local_offsets = offsets[batch_size * handle_table:batch_size *
(handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size + 1).sub(offsets[batch_size * (handle_table)])
else:
local_offsets = offsets[batch_size * handle_table:batch_size *
(handle_table + 1)] - offsets[batch_size * (handle_table)]
# local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)]
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).sub(offsets[batch_size * (handle_table)])
local_per_sample_weights = None
if per_sample_weights is not None:
local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
with record_function("(tablewise) tablewise forward"):
local_output_list.append(self.freq_aware_embedding_bag_list[i](local_indices, local_offsets,
local_per_sample_weights))
@@ -140,3 +277,21 @@ class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
if shape_hook is not None:
output_full = shape_hook(output_full)
return output_full
def element_size(self):
if len(self.assigned_table_list) == 0:
return 0
return self.freq_aware_embedding_bag_list[0].cache_weight_mgr.weight.element_size()
def print_comm_stats_(self):
cuda_to_cpu_elem_num = 0
cpu_to_cuda_elem_num = 0
for freq_aware_embedding_bag in self.freq_aware_embedding_bag_list:
cuda_to_cpu_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cuda_to_cpu_numel
cpu_to_cuda_elem_num += freq_aware_embedding_bag.cache_weight_mgr._cpu_to_cuda_numel
print(
f"CUDA->CPU num: {cuda_to_cpu_elem_num / 1e6} M elem"
)
print(
f"CPU->CUDA num: {cpu_to_cuda_elem_num / 1e6} M elem"
)

View File

@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
from typing import List
NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,9 +209,10 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
# initialize weight
# 3 feature tables. idx: 0~5, 6~10, 11~17
weight_table1 = torch.rand(6, 5)
weight_table2 = torch.rand(5, 5)
weight_table3 = torch.rand(7, 5)
weight_tables = torch.rand(18,5)
weight_table1 = weight_tables[0:6]
weight_table2 = weight_tables[6:11]
weight_table3 = weight_tables[11:18]
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
@@ -219,14 +220,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
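# each rank passes only its assigned tables as _weight: rank 0 owns tables 1 and 2, rank 1 owns table 3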
if rank == 0:
_weight = torch.cat([weight_table1, weight_table2],0)
else:
_weight = weight_table3
model = ParallelFreqAwareEmbeddingBagTablewise(
embedding_bag_config_list,
embedding_dim=5,
_weight=_weight,
include_last_offset=True,
cuda_row_num=8,
buffer_size=0,
evict_strategy=EvictionStrategy.LFU,
include_last_offset=True
)
# demo explain:
# explain
'''
batch feature 1 feature 2 feature 3
input0 [1,2,3] [6,7] []
@@ -244,28 +251,27 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
fake_grad = rand_grad[0:2]
else:
fake_grad = rand_grad[2:]
res.backward(fake_grad)
optimizer.step()
optimizer.zero_grad()
# check correctness on weight_table2
# check correctness
if rank == 0:
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
include_last_offset=True,
freeze=False).to(device)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
ref_grad = rand_grad[:, 5:10]
ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
ref_res.backward(ref_grad)
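# the reference EmbeddingBag holds all 18 rows, so the full forward is replayed with global ids
# and the per-table gradient slices (5 columns each) are stacked along the bag dimension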
ref_fake_grad = torch.cat(rand_grad.split(5,1),0)
ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
ref_res.backward(ref_fake_grad)
ref_optimizer.step()
ref_optimizer.zero_grad()
model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush() # update cpu weight
recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
), f"{recover_weight - ref_model.weight.detach().cpu()}"
model.cache_weight_mgr.flush()
recover_weight = model.cache_weight_mgr.weight.to(device)
ref_weight = ref_model.weight.detach()[:11]
assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
device = torch.device('cuda', torch.cuda.current_device())