2022-04-28 09:45:06 +00:00
|
|
|
import torch
|
2022-05-19 04:44:59 +00:00
|
|
|
import torch.nn.functional as F
|
|
|
|
from typing import Optional
|
2022-04-28 09:45:06 +00:00
|
|
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
2022-05-18 06:54:51 +00:00
|
|
|
from colossalai.nn.layer.parallel_1d._utils import reduce_input
|
2022-04-28 09:45:06 +00:00
|
|
|
from colossalai.core import global_context as gpc
|
2022-05-19 04:44:59 +00:00
|
|
|
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
|
|
|
|
from ._utils import GeneralTensor, convert_to_colo_tensor
|
2022-05-13 07:13:52 +00:00
|
|
|
|
2022-04-28 09:45:06 +00:00
|
|
|
|
2022-05-19 04:44:59 +00:00
|
|
|
def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    """Embedding lookup against a lookup table column-sharded over the embedding dim.

    Each rank holds a ``(num_embeddings, embedding_dim / P)`` slice of the table, i.e.
    every row but only part of each row. The indices are replicated to all ranks, each
    rank produces its slice of the last output dim, and the slices are then converted
    to a replicated dist spec so every rank ends up with the full embedding dim.

    Args:
        input_tensor: index tensor; converted to a replicated dist spec first.
        weight: column-sharded lookup table carrying a ``ComputePattern.TP1D`` action.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse: forwarded to
            ``F.embedding`` for the local slice.

    Returns:
        ColoTensor with a replicated dist spec over the weight's process group.
    """
    tp_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
    process_group = weight.spec.get_process_group()

    # Every rank owns all rows of the table, so it needs every index.
    replicated_input = input_tensor.convert_to_dist_spec(distspec.replicate(process_group))

    # NOTE(review): with max_norm set, the norm is computed over this rank's column slice
    # only, not the full embedding vector — confirm whether that is acceptable.
    local_slice = F.embedding(replicated_input,
                              weight,
                              padding_idx=padding_idx,
                              max_norm=max_norm,
                              norm_type=norm_type,
                              scale_grad_by_freq=scale_grad_by_freq,
                              sparse=sparse)

    # The local result is this rank's shard of the last dimension.
    shard_count = weight.spec.get_process_group_size()
    sharded_spec = TensorSpec(distspec.shard(process_group, [-1], [shard_count]),
                              [ParallelAction(priority=1, parallel_mode=tp_action.parallel_mode)])
    sharded_output = ColoTensor.from_torch_tensor(local_slice, spec=sharded_spec)

    # Gather the column slices back into a replicated, full-width output.
    return sharded_output.convert_to_dist_spec(distspec.replicate(process_group))
|
|
|
|
|
2022-05-13 07:13:52 +00:00
|
|
|
|
2022-05-19 04:44:59 +00:00
|
|
|
def colo_embedding_1Drow(input_tensor: ColoTensor,
                         weight: ColoTensor,
                         padding_idx: Optional[int] = None,
                         max_norm: Optional[float] = None,
                         norm_type: float = 2.0,
                         scale_grad_by_freq: bool = False,
                         sparse: bool = False) -> ColoTensor:
    """Embedding lookup against a lookup table row-sharded over the vocabulary.

    Each rank holds a ``(num_embeddings / P, embedding_dim)`` slice of the table. The
    indices are replicated, each rank looks up only the indices inside its own
    vocabulary range, zeroes the rows for indices owned by other ranks, and the partial
    results are combined across the tensor-parallel group with ``reduce_input``.

    Args:
        input_tensor: index tensor; converted to a replicated dist spec first.
        weight: row-sharded lookup table carrying a ``ComputePattern.TP1D`` action.
        padding_idx: padding index in *global* vocabulary coordinates (negative values
            count from the end of the full vocabulary). It is translated to this rank's
            local coordinates, or dropped on ranks that do not own that row.
        max_norm, norm_type, scale_grad_by_freq, sparse: forwarded to ``F.embedding``
            for the local shard.

    Returns:
        ColoTensor with a replicated dist spec over the weight's process group.
    """
    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

    # Vocabulary range [vocab_start_index, vocab_end_index) owned by this rank;
    # assumes every rank holds an equally sized shard.
    tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
    num_embeddings_per_partition = weight.size(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition

    # padding_idx refers to the full vocabulary. Passing it unchanged to the local
    # F.embedding would mark the wrong local row as padding on most ranks, and trip
    # F.embedding's "within num_embeddings" assertion whenever it is >= the shard size.
    # Translate it to a local index, or None when another rank owns that row.
    local_padding_idx = None
    if padding_idx is not None:
        if padding_idx < 0:
            # Resolve against the full vocabulary size; assumes the process group size
            # equals the number of vocabulary shards.
            padding_idx += num_embeddings_per_partition * weight.spec.get_process_group_size()
        if vocab_start_index <= padding_idx < vocab_end_index:
            local_padding_idx = padding_idx - vocab_start_index

    # Build the mask of positions whose index is owned by another rank.
    input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)

    # Shift indices into local coordinates; foreign indices are pointed at local row 0
    # as a placeholder and their outputs are zeroed after the lookup.
    # TODO(jzy) masked_input may be an activation managed by ColoTensor.
    masked_input = input_tensor.clone() - vocab_start_index
    masked_input[input_mask] = 0

    # NOTE(review): placeholder lookups of local row 0 are still seen by F.embedding, so
    # with max_norm row 0 may be renormalized in place and with scale_grad_by_freq its
    # frequency count is inflated — confirm whether these options need special handling.
    partial_output = F.embedding(masked_input,
                                 weight,
                                 padding_idx=local_padding_idx,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse)

    # Zero the placeholder rows so only the owning rank contributes to each position.
    partial_output[input_mask, :] = 0.

    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, parallel_action.parallel_mode)
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
    return output
|
|
|
|
|
2022-05-13 07:13:52 +00:00
|
|
|
|
2022-05-19 04:44:59 +00:00
|
|
|
def colo_embedding_1d(mode: str,
                      input_tensor: ColoTensor,
                      weight: ColoTensor,
                      padding_idx: Optional[int] = None,
                      max_norm: Optional[float] = None,
                      norm_type: float = 2.0,
                      scale_grad_by_freq: bool = False,
                      sparse: bool = False) -> ColoTensor:
    """Dispatch a 1D tensor-parallel embedding to the row- or column-sharded implementation.

    Args:
        mode: ``'row'`` dispatches to ``colo_embedding_1Drow`` (table split along the
            vocabulary dim); ``'col'`` dispatches to ``colo_embedding_1Dcol`` (table
            split along the embedding dim).
        input_tensor: index tensor.
        weight: sharded lookup table.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse: forwarded
            unchanged to the selected implementation.

    Returns:
        The ColoTensor produced by the selected implementation.

    Raises:
        ValueError: if ``mode`` is neither ``'row'`` nor ``'col'``.
    """
    funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in funcs:
        raise ValueError(f"Unsupported 1D embedding mode {mode!r}, expected 'row' or 'col'")
    return funcs[mode](input_tensor,
                       weight,
                       padding_idx=padding_idx,
                       max_norm=max_norm,
                       norm_type=norm_type,
                       scale_grad_by_freq=scale_grad_by_freq,
                       sparse=sparse)
|
|
|
|
|
|
|
|
|
|
|
|
@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
                   weight: GeneralTensor,
                   padding_idx: Optional[int] = None,
                   max_norm: Optional[float] = None,
                   norm_type: float = 2.0,
                   scale_grad_by_freq: bool = False,
                   sparse: bool = False):
    """Handle ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.

    Look up ``input_tensor`` in the ``weight`` table. Both arguments are first
    converted to ColoTensors. A weight without a model-parallel spec uses the native
    op. A weight whose spec has the ``ComputePattern.TP1D`` compute pattern is routed
    to the row- or column-sharded 1D implementation.

    Raises:
        NotImplementedError: for any other weight spec.
    """
    input_tensor = convert_to_colo_tensor(input_tensor)
    weight = convert_to_colo_tensor(weight)

    lookup_kwargs = dict(padding_idx=padding_idx,
                         max_norm=max_norm,
                         norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse)

    # No model parallelism applied: run the native op on the full table.
    if not weight.has_spec():
        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
        return ColoTensor.from_torch_tensor(F.embedding(input_tensor, weight, **lookup_kwargs))

    # Beyond this point only single (1D) tensor parallelism is handled.
    if not weight.spec.has_compute_pattern(ComputePattern.TP1D):
        raise NotImplementedError

    if weight.spec.is_1D_row():
        mode = 'row'
    elif weight.spec.is_1D_col():
        mode = 'col'
    else:
        raise NotImplementedError
    return colo_embedding_1d(mode, input_tensor, weight, **lookup_kwargs)
|