mirror of https://github.com/hpcaitech/ColossalAI

[FAW] reorganize the inheritance struct of FreqCacheEmbedding (#1448)

parent 5a52e21fe3
commit 9f3eed66eb
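In brief, this commit makes ParallelFreqAwareEmbeddingBag inherit from FreqAwareEmbeddingBag instead of having both derive from BaseEmbeddingBag, renames CachedParamMgr.cpu_weight to CachedParamMgr.weight, and folds the separate preprocess() step into the constructors. A rough before/after sketch of the user-facing change, assuming this revision of ColossalAI and a CUDA device; the sizes are illustrative only:

import torch
from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag

NUM_EMBED, EMBED_DIM, CUDA_ROWS = 10, 8, 4

# Before this commit: two-step construction with an explicit preprocess() call.
# model = FreqAwareEmbeddingBag(NUM_EMBED, EMBED_DIM, mode='mean')
# model.preprocess(cuda_row_num=CUDA_ROWS, ids_freq_mapping=None)

# After this commit: cache configuration moves into the constructor.
model = FreqAwareEmbeddingBag(NUM_EMBED, EMBED_DIM, mode='mean',
                              cuda_row_num=CUDA_ROWS, ids_freq_mapping=None)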
@@ -23,48 +23,61 @@ class CachedParamMgr(torch.nn.Module):
         self.elem_size_in_byte = weight.element_size()
 
-        self.cuda_cached_weight = torch.nn.Parameter(
-            torch.zeros(self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype))
-
-        if weight.device.type == 'cuda':
-            weight = weight.cpu()
-
-        # pin memory cpu for higher CPU-GPU copy bandwidth
-        self.cpu_weight = weight.contiguous().pin_memory()
-
-        # map original id to new id with respect to frequency
-        # id -> cpu_row_idx
-        self.register_buffer(
-            "idx_map",
-            torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
-            persistent=False,
-        )
-
-        # cached_idx_map: gpu_row_idx -> cpu_row_idx
-        self.register_buffer("cached_idx_map",
-                             torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
-                                         dtype=torch.long).fill_(-1),
-                             persistent=False)
-
-        # cpu_row_id -> gpu_row_idx.
-        # gpu_row_idx as -1 means cpu_row_id not in CUDA.
-        self.register_buffer("inverted_cached_idx",
-                             torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
-                                         dtype=torch.long).fill_(-1),
-                             persistent=False)
-
-        self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
-
-        # index copy buffer size should less than 10% of cuda weight.
-        if self.buffer_size > 0:
-            self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
+        # weight configure
+        self._init_weight(weight)
 
         # Perf log
         self.num_hits_history = []
         self.num_miss_history = []
         self.num_write_back_history = []
         self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()
 
+    def _init_weight(self, weight):
+        if self.cuda_row_num > 0:
+            # Enable cache with introducing auxiliary data structures
+            self.cuda_cached_weight = torch.nn.Parameter(
+                torch.zeros(self.cuda_row_num,
+                            self.embedding_dim,
+                            device=torch.cuda.current_device(),
+                            dtype=weight.dtype))
+
+            # pin memory cpu for higher CPU-GPU copy bandwidth
+            self.weight = weight.contiguous().cpu().pin_memory()
+
+            # map original id to new id with respect to frequency
+            # id -> cpu_row_idx
+            self.register_buffer(
+                "idx_map",
+                torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
+                persistent=False,
+            )
+
+            # cached_idx_map: gpu_row_idx -> cpu_row_idx
+            self.register_buffer("cached_idx_map",
+                                 torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(-1),
+                                 persistent=False)
+
+            # cpu_row_id -> gpu_row_idx.
+            # gpu_row_idx as -1 means cpu_row_id not in CUDA.
+            self.register_buffer("inverted_cached_idx",
+                                 torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(-1),
+                                 persistent=False)
+
+            self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())
+
+            # index copy buffer size should less than 10% of cuda weight.
+            if self.buffer_size > 0:
+                self.limit_buff_index_copyer = LimitBuffIndexCopyer(self.buffer_size)
+        else:
+            # Disable cache so that FreqCacheEmbedding is compatible with vanilla EmbeddingBag
+            # self.weight = torch.nn.Parameter(weight)
+            # self.cuda_cached_weight = self.weight
+            raise NotImplementedError()
 
     def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
         """
         access a chunk of CPU weight.
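The pin_memory() call in _init_weight is what the comment refers to: page-locked host memory enables DMA transfers and asynchronous host-to-device copies. A standalone sketch of the pattern, independent of this repository and requiring a CUDA device:

import torch

if torch.cuda.is_available():
    rows, dim = 1000, 128
    # Page-locked (pinned) host copy of the embedding table.
    weight = torch.randn(rows, dim).contiguous().pin_memory()

    # Pinned memory permits non-blocking host-to-device copies.
    cache = torch.empty(64, dim, device='cuda')
    cache.copy_(weight[:64], non_blocking=True)
    torch.cuda.synchronize()    # wait for the async copy before using `cache`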
@@ -76,9 +89,9 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
         """
-        return self.cpu_weight.data.view(-1).narrow(0,
-                                                    int(chunk_id) * self.embedding_dim,
-                                                    self.embedding_dim).view(1, self.embedding_dim)
+        return self.weight.data.view(-1).narrow(0,
+                                                int(chunk_id) * self.embedding_dim,
+                                                self.embedding_dim).view(1, self.embedding_dim)
 
     @property
     def cuda_available_chunk_num(self):
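cpu_weight_data (now reading self.weight) returns the row as a view, not a copy, so writes through the returned tensor land in the pinned CPU table. A minimal standalone demonstration of the view(-1).narrow(...) idiom used above:

import torch

num_embeddings, embedding_dim = 4, 3
weight = torch.arange(12, dtype=torch.float32).view(num_embeddings, embedding_dim)

chunk_id = 2
row = weight.data.view(-1).narrow(0, chunk_id * embedding_dim, embedding_dim).view(1, embedding_dim)

row.fill_(0)    # narrow() returned a view, so this mutates `weight` in place
assert torch.equal(weight[2], torch.zeros(3))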
@@ -86,7 +99,7 @@ class CachedParamMgr(torch.nn.Module):
 
     @torch.no_grad()
     def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
-        """reorder the cpu_weight according to ids' frequency in dataset before training.
+        """reorder the weight according to ids' frequency in dataset before training.
         Also Build the IndexMappingTable, aka index_mapping_table.
         Execute only once before training.
         Args:
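reorder sorts the table so the most frequent ids occupy the lowest row indices, which is what lets a contiguous prefix of rows serve as the warmup set. A small sketch of the id-to-row mapping it builds; this is my own illustration, not the repository's exact code:

import torch

# ids_freq_mapping: index is the original id, value is its access frequency
ids_freq_mapping = torch.tensor([3, 50, 1, 20])

# Rank ids by descending frequency: the hottest id gets row 0.
sorted_ids = torch.argsort(ids_freq_mapping, descending=True)    # tensor([1, 3, 0, 2])

# idx_map: original id -> new frequency-ordered row index
idx_map = torch.empty_like(sorted_ids)
idx_map[sorted_ids] = torch.arange(len(sorted_ids))
print(idx_map)    # tensor([2, 0, 3, 1]); id 1 is hottest, so it maps to row 0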
@@ -112,11 +125,10 @@ class CachedParamMgr(torch.nn.Module):
                 self.limit_buff_index_copyer.index_copy(0,
                                                         src_index=preload_row_ids,
                                                         tgt_index=preload_slot_ids,
-                                                        src=self.cpu_weight.view(self.num_embeddings, -1),
+                                                        src=self.weight.view(self.num_embeddings, -1),
                                                         tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
             else:
-                preload_chunks = self.cpu_weight.view(self.num_embeddings, -1).index_select(0,
-                                                                                            preload_row_ids).cuda()
+                preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
                 self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
 
             # update auxiliary info
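Both preload branches move the same rows; they differ only in whether the transfer is staged through the bounded LimitBuffIndexCopyer. The gather/scatter primitive pair they share, shown standalone:

import torch

src = torch.arange(20, dtype=torch.float32).view(5, 4)    # stand-in for the CPU weight
tgt = torch.zeros(3, 4)                                   # stand-in for the CUDA cache

src_index = torch.tensor([4, 0])    # rows to preload from src
tgt_index = torch.tensor([1, 2])    # cache slots that receive them

rows = src.index_select(0, src_index)    # gather rows 4 and 0
tgt.index_copy_(0, tgt_index, rows)      # scatter them into slots 1 and 2

assert torch.equal(tgt[1], src[4]) and torch.equal(tgt[2], src[0])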
@@ -133,7 +145,7 @@ class CachedParamMgr(torch.nn.Module):
         slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
         chunk_ids = self.cached_idx_map[slots]
         chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
-        self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
+        self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
         self.cached_idx_map.index_fill_(0, slots, -1)
         self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
         self._cuda_available_row_num += slots.numel()
@@ -237,11 +249,11 @@ class CachedParamMgr(torch.nn.Module):
                                                     src_index=evict_gpu_row_idxs,
                                                     tgt_index=evict_info.cpu(),
                                                     src=self.cuda_cached_weight.view(self.cuda_row_num, -1),
-                                                    tgt=self.cpu_weight.view(self.num_embeddings, -1))
+                                                    tgt=self.weight.view(self.num_embeddings, -1))
         else:
             # allocate tmp memory on CPU and copy rows on CUDA to CPU.
             rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
-            self.cpu_weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
+            self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)
 
         self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
         self.inverted_cached_idx.index_fill_(0, evict_info, -1)
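After the evicted rows are written back, both direction maps are invalidated with index_fill_. A toy version of that bookkeeping, using -1 as the "not cached" sentinel exactly as the class does:

import torch

cuda_row_num, num_embeddings = 4, 10
cached_idx_map = torch.full((cuda_row_num,), -1, dtype=torch.long)         # gpu slot -> cpu row
inverted_cached_idx = torch.full((num_embeddings,), -1, dtype=torch.long)  # cpu row -> gpu slot

# Suppose cpu rows 7 and 2 currently occupy gpu slots 0 and 3.
cached_idx_map[torch.tensor([0, 3])] = torch.tensor([7, 2])
inverted_cached_idx[torch.tensor([7, 2])] = torch.tensor([0, 3])

# Evict gpu slot 3: reset both maps to the sentinel.
evict_gpu_row_idxs = torch.tensor([3])
evict_info = cached_idx_map[evict_gpu_row_idxs]    # the cpu rows being evicted
cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
inverted_cached_idx.index_fill_(0, evict_info, -1)

assert cached_idx_map[3] == -1 and inverted_cached_idx[2] == -1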
@@ -259,10 +271,10 @@ class CachedParamMgr(torch.nn.Module):
             self.limit_buff_index_copyer.index_copy(0,
                                                     src_index=cpu_row_idxs.cpu(),
                                                     tgt_index=slots,
-                                                    src=self.cpu_weight.view(self.num_embeddings, -1),
+                                                    src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
-            rows = self.cpu_weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
+            rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
             self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
@@ -9,15 +9,50 @@ from torch.nn.parameter import Parameter
 
 class FreqAwareEmbeddingBag(BaseEmbeddingBag):
 
-    def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs):
-        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs)
-        self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype)
-
-    def preprocess(self,
-                   cuda_row_num: int,
-                   ids_freq_mapping: Optional[List[int]] = None,
-                   warmup_ratio=0.7,
-                   buffer_size=50_000):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        mode='mean',
+        include_last_offset=False,
+        dtype=None,
+        device=None,
+        cuda_row_num=0,
+        ids_freq_mapping=None,
+        warmup_ratio=0.7,
+        buffer_size=50_000,
+    ):
+        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
+                                                    scale_grad_by_freq, sparse, mode, include_last_offset)
+
+        if _weight is None:
+            _weight = self._weight_alloc(dtype, device)
+        else:
+            _weight = _weight
+
+        # configure weight & cache
+        self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size)
+
+    def _weight_alloc(self, dtype, device):
+        weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device, pin_memory=True)
+        with torch.no_grad():
+            weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+            if self.padding_idx is not None:
+                weight[self.padding_idx].fill_(0)
+        return weight
+
+    def _preprocess(self,
+                    weight,
+                    cuda_row_num: int,
+                    ids_freq_mapping: Optional[List[int]] = None,
+                    warmup_ratio=0.7,
+                    buffer_size=50_000):
         """
         Called after initialized.
         Reorder the weight rows according to the ids_freq_mapping.
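This hunk (apparently the FreqAwareEmbeddingBag module) is the heart of the reorganization: __init__ now allocates the weight through an overridable _weight_alloc hook and immediately wires up the cache via _preprocess, so the parallel subclass only has to swap the allocation policy. A self-contained sketch of that template-method shape, with made-up class names:

import torch

class Base:
    def __init__(self, num_embeddings, embedding_dim, dtype=None, device=None):
        self.num_embeddings, self.embedding_dim = num_embeddings, embedding_dim
        # The base constructor drives the flow; subclasses customize one hook.
        self.weight = self._weight_alloc(dtype, device)

    def _weight_alloc(self, dtype, device):
        w = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
        with torch.no_grad():
            w.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
        return w

class Sharded(Base):
    def __init__(self, num_embeddings, embedding_dim, shard_dim, **kw):
        self.shard_dim = shard_dim    # set before super(), since the hook needs it
        super().__init__(num_embeddings, embedding_dim, **kw)

    def _weight_alloc(self, dtype, device):
        # Allocate only this rank's column shard, as the parallel subclass does.
        return torch.empty(self.num_embeddings, self.shard_dim, dtype=dtype, device=device)

print(Sharded(10, 8, shard_dim=4).weight.shape)    # torch.Size([10, 4])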
@@ -27,7 +62,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
     def forward(self, indices, offsets=None, per_sample_weights=None):
@@ -42,8 +77,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
 
     @property
     def weight(self):
-        assert self.cache_weight_mgr is not None
-        return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings)
+        return self.cache_weight_mgr.weight
 
     def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
         yield 'weight', self.cache_weight_mgr.cuda_cached_weight
@@ -51,6 +85,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         yield self.cache_weight_mgr.cuda_cached_weight
 
+    ############################# Perf Log ###################################
+
     @property
     def num_hits_history(self):
         return self.cache_weight_mgr.num_hits_history
@@ -2,12 +2,12 @@ import torch
 import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
-from .base_embedding import BaseEmbeddingBag
+from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .cache_mgr import CachedParamMgr
 from torch.nn.parameter import Parameter
 from colossalai.nn._ops._utils import dual_all_to_all
 
-from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
 
 
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
@@ -29,71 +29,48 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     return offset, offset + size_list[rank], False
 
 
-class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
+class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
 
-    def __init__(self,
-                 num_embeddings,
-                 embedding_dim,
-                 padding_idx=None,
-                 max_norm=None,
-                 norm_type=2.,
-                 scale_grad_by_freq=False,
-                 sparse=False,
-                 _weight=None,
-                 mode='mean',
-                 include_last_offset=False,
-                 dtype=None,
-                 debug=True):
-        super(ParallelFreqAwareEmbeddingBag,
-              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
-                             sparse, mode, include_last_offset)
-
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        mode='mean',
+        include_last_offset=False,
+        dtype=None,
+        device=None,
+        cuda_row_num=0,
+        ids_freq_mapping=None,
+        warmup_ratio=0.7,
+        buffer_size=50_000,
+    ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
-        self.debug = debug
 
         self.partition_start_index, self.partition_end_index, divisible = get_partition(
             embedding_dim, self.rank, self.world_size)
         self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
 
-        if _weight is None:
-            colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
-                                              dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
-                                              compute_attr=ComputePattern.TP1D)
-            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
-                                                                       self.embedding_dim_per_partition,
-                                                                       device='cpu',
-                                                                       dtype=dtype),
-                                                           requires_grad=True,
-                                                           spec=colo_tensor_spec)
-            self.init_parameters()
-        else:
-            assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
-            self._weight = _weight
-
-    @property
-    def weight(self):
-        return self.cache_weight_mgr.cpu_weight
-
-    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
-        yield 'weight', self.cache_weight_mgr.cuda_cached_weight
-
-    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
-        yield self.cache_weight_mgr.cuda_cached_weight
-
-    @torch.no_grad()
-    def init_parameters(self):
-        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
-        if self.padding_idx is not None:
-            self._weight[self.padding_idx].fill_(0)
-
-    def preprocess(self,
-                   cuda_row_num: int,
-                   ids_freq_mapping: Optional[List[int]] = None,
-                   warmup_ratio: float = 0.7,
-                   buffer_size: int = 50_000):
-        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
-        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
+        super(ParallelFreqAwareEmbeddingBag,
+              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+                             sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
+                             warmup_ratio, buffer_size)
+
+    def _weight_alloc(self, dtype, device):
+        colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
+                                          dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
+                                          compute_attr=ComputePattern.TP1D)
+        return ColoTensor.from_torch_tensor(torch.empty(self.num_embeddings,
+                                                        self.embedding_dim_per_partition,
+                                                        device=device,
+                                                        dtype=dtype),
+                                            spec=colo_tensor_spec)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
         with torch.no_grad():
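ParallelFreqAwareEmbeddingBag shards the table along the last dimension (TP1D), so its _weight_alloc requests only embedding_dim_per_partition columns and relies on get_partition for the split. A standalone approximation of an even column partition, consistent with the (start, end, divisible) return shape visible above; this is my sketch, not the repository's exact splitting rule:

from typing import Tuple

def get_partition_sketch(embedding_dim: int, rank: int, world_size: int) -> Tuple[int, int, bool]:
    # Spread embedding_dim columns over world_size ranks, front-loading remainders.
    base, rem = divmod(embedding_dim, world_size)
    sizes = [base + (1 if r < rem else 0) for r in range(world_size)]
    offset = sum(sizes[:rank])
    return offset, offset + sizes[rank], rem == 0

# embedding_dim=8 over 2 ranks: columns [0, 4) and [4, 8).
assert get_partition_sketch(8, 0, 2) == (0, 4, True)
assert get_partition_sketch(8, 1, 2) == (4, 8, True)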
@@ -107,29 +84,42 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
             output_shard = shape_hook(output_shard)
 
         output_full = dual_all_to_all(output_shard,
-                                      self._weight.get_process_group(),
+                                      self.weight.get_process_group(),
                                       scatter_dim=scatter_dim,
                                       gather_dim=gather_dim)
         return output_full
 
     @classmethod
-    def from_pretrained(cls,
-                        embedding: torch.Tensor,
-                        freeze: bool = True,
-                        padding_idx: Optional[int] = None,
-                        max_norm: Optional[float] = None,
-                        norm_type: float = 2.,
-                        scale_grad_by_freq: bool = False,
-                        sparse: bool = False,
-                        mode: str = 'mean',
-                        include_last_offset: bool = False,
-                        debug: bool = True,
-                        cuda_row_num: int = 100_000,
-                        ids_freq_mapping: Optional[List[int]] = None,
-                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
+    def from_pretrained(
+            cls,
+            embedding: torch.Tensor,
+            freeze: bool = True,
+            padding_idx: Optional[int] = None,
+            max_norm: Optional[float] = None,
+            norm_type: float = 2.,
+            scale_grad_by_freq: bool = False,
+            sparse: bool = False,
+            mode: str = 'mean',
+            include_last_offset: bool = False,
+            cuda_row_num: int = 100_000,
+            ids_freq_mapping: Optional[List[int]] = None,
+            warmup_ratio: float = 0.7,
+            buffer_size: int = 50_000,
+    ) -> 'ParallelFreqAwareEmbeddingBag':
         rows, cols = embedding.shape
-        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
-                            include_last_offset, debug)
-        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
+        embedding_bag = cls(rows,
+                            cols,
+                            padding_idx,
+                            max_norm,
+                            norm_type,
+                            scale_grad_by_freq,
+                            sparse,
+                            embedding,
+                            mode,
+                            include_last_offset,
+                            cuda_row_num=cuda_row_num,
+                            ids_freq_mapping=ids_freq_mapping,
+                            warmup_ratio=warmup_ratio,
+                            buffer_size=buffer_size)
         embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze
         return embedding_bag
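One caveat worth flagging in from_pretrained, unchanged by this commit: requires_grad_ is a Tensor method, so assigning to it (requires_grad_ = not freeze) merely shadows the method with a bool attribute and never toggles autograd; the test below passes only because a fresh Parameter already has requires_grad=True. A standalone demonstration:

import torch

p = torch.nn.Parameter(torch.zeros(2, 2))

# As written in the diff: binds a bool attribute over the bound method,
# leaving autograd tracking unchanged.
p.requires_grad_ = False
assert p.requires_grad is True

# The conventional spellings actually toggle tracking:
p.requires_grad = False        # attribute assignment
# or: p.requires_grad_(False)  # in-place method call
assert p.requires_grad is False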
@@ -10,7 +10,8 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
+    ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
 
 NUM_EMBED, EMBED_DIM = 10, 8
@@ -99,13 +100,12 @@ def test_reorder_with_freq():
 
 def test_freq_aware_embed():
     device = torch.device('cuda', 0)
-    model = FreqAwareEmbeddingBag(
-        NUM_EMBED,
-        EMBED_DIM,
-        mode='mean',
-        include_last_offset=True,
-    ).to(device)
-    model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
+    model = FreqAwareEmbeddingBag(NUM_EMBED,
+                                  EMBED_DIM,
+                                  mode='mean',
+                                  include_last_offset=True,
+                                  cuda_row_num=BATCH_SIZE * 2,
+                                  ids_freq_mapping=None).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
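The test compares against torch.nn.EmbeddingBag with include_last_offset=True, where offsets carries one extra trailing entry marking the end of the last bag. A standalone reminder of those semantics:

import torch

weight = torch.arange(10, dtype=torch.float32).view(5, 2)
bag = torch.nn.EmbeddingBag.from_pretrained(weight, mode='mean', include_last_offset=True)

indices = torch.tensor([0, 2, 4, 1])
offsets = torch.tensor([0, 3, 4])    # two bags: indices[0:3] and indices[3:4]

out = bag(indices, offsets)
assert out.shape == (2, 2)
assert torch.allclose(out[0], weight[[0, 2, 4]].mean(dim=0))
assert torch.allclose(out[1], weight[1])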
@@ -159,11 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
 
     set_seed(4321)
     weight = torch.rand(num_embed, embed_dim)
-    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
+    coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)
 
     # initialize the tensor spec for the embedding weight parameter,
     # which is an ColoParameter.
-    coloweight.process_group = ProcessGroup(tp_degree=world_size)
+    coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
 
     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
@@ -171,12 +171,12 @@ def run_parallel_freq_aware_embed(rank, world_size):
                                                           freeze=False,
                                                           cuda_row_num=batch_size * 2)
 
-    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
+    assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
     weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
-    assert torch.allclose(
-        weight_in_rank,
-        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"
+    print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
+    assert torch.allclose(weight_in_rank,
+                          model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"
 
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
@@ -211,7 +211,7 @@ def run_parallel_freq_aware_embed(rank, world_size):
         ref_optimizer.zero_grad()
 
     model.cache_weight_mgr.flush()
-    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
+    weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
     if rank == 0:
         recover_weight = torch.cat(weight_list, dim=1)
         assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
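The final check reassembles the full table from each rank's column shard after flushing the cache; torch.tensor_split along dim -1 and torch.cat along dim 1 are exact inverses here. Standalone:

import torch

world_size = 2
weight = torch.rand(10, 8)

shards = torch.tensor_split(weight, world_size, -1)    # one column shard per rank
recovered = torch.cat(shards, dim=1)                   # what rank 0 does after the gather

assert torch.equal(recovered, weight)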
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 if __name__ == '__main__':
+    # test_cachemgr()
     # test_freq_aware_embed()
-    # test_chunkmgr_admit()
     test_parallel_freq_aware_embed(2)