adapted for sequence parallel (#163)

pull/171/head
Frank Lee 2022-01-20 13:44:51 +08:00 committed by GitHub
parent a2e649da39
commit e2089c5c15
17 changed files with 432 additions and 119 deletions

View File

@@ -32,7 +32,7 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
loss.backward()
def step(self):
self.optim.step()
return self.optim.step()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
pass

View File

@@ -26,6 +26,7 @@ class ParallelMode(Enum):
# sequence parallel
SEQUENCE = 'sequence'
SEQUENCE_DP = 'sequence_dp'
# 1D Parallel
PARALLEL_1D = '1d'

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from .initializer_tensor import Initializer_Tensor
@@ -7,6 +8,43 @@ from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence_DP(ProcessGroupInitializer):
'''A ProcessGroupInitializer for the all-reduce in sequence parallelism.
In sequence parallelism, each GPU holds a full copy of the model weights,
so gradient all-reduce occurs across all processes in the same pipeline stage.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dp_size = self.world_size // self.pipeline_parallel_size
self.num_group = self.pipeline_parallel_size
def init_dist_group(self):
'''Initialize Sequence Parallel process groups used for gradient all-reduce.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: tuple
'''
local_rank = None
ranks_in_group = None
process_group = None
group_world_size = None
mode = ParallelMode.SEQUENCE_DP
for i in range(self.num_group):
ranks = [i * self.dp_size + j for j in range(self.dp_size)]
group = dist.new_group(ranks)
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
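As a quick illustration of the grouping logic above, here is a minimal standalone sketch with hypothetical sizes (no process groups are created, so it runs on a single machine):

# Sketch: enumerate the rank groups Initializer_Sequence_DP would build.
# Hypothetical setup: 8 processes, 2 pipeline stages -> dp_size = 4.
world_size = 8
pipeline_parallel_size = 2

dp_size = world_size // pipeline_parallel_size
num_group = pipeline_parallel_size

groups = []
for i in range(num_group):
    # one gradient all-reduce group per pipeline stage, following the loop above
    ranks = [i * dp_size + j for j in range(dp_size)]
    groups.append(ranks)

print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]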
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence(ProcessGroupInitializer):
'''A ProcessGroupInitializer for sequence parallelism.
@@ -15,13 +53,27 @@ class Initializer_Sequence(ProcessGroupInitializer):
def __init__(self,
*args, **kwargs):
super().__init__(*args, **kwargs)
# reuse tensor parallel code
self._initializer = Initializer_Tensor(*args, **kwargs)
# reuse tensor parallel initializer code
self._sequence_initializer = Initializer_Tensor(*args, **kwargs)
self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs)
def init_dist_group(self):
local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group()
'''Initialize sequence parallel process groups and assign local ranks and groups to each GPU.
Sequence parallelism requires two process groups. The first is used in the model forward pass, where
processes exchange partial query, key and value embeddings to compute self-attention values. The second
is used for the all-reduce that synchronizes the model parameters.
:return: sequence parallelism's information
:rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
'''
parallel_setting = []
local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group()
# change mode to sequence
mode = ParallelMode.SEQUENCE
return local_rank, group_world_size, process_group, ranks_in_group, mode
parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode))
parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
return parallel_setting
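For reference, a minimal sketch of the structure returned above, assuming a hypothetical 4-GPU run without pipeline parallelism; the process group handles are replaced by placeholder strings so the sketch runs standalone:

# Illustrative only: one tuple per process group, first the ring communication
# group (SEQUENCE), then the gradient all-reduce group (SEQUENCE_DP).
world_size, rank = 4, 2
ranks = list(range(world_size))

parallel_setting = [
    # (local_rank, group_world_size, process_group, ranks_in_group, mode)
    (ranks.index(rank), world_size, '<ring communication group>', ranks, 'SEQUENCE'),
    (ranks.index(rank), world_size, '<gradient all-reduce group>', ranks, 'SEQUENCE_DP'),
]

for local_rank, group_world_size, group, ranks_in_group, mode in parallel_setting:
    print(mode, local_rank, group_world_size, ranks_in_group)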

View File

@@ -77,7 +77,7 @@ class Engine:
"""
self._all_reduce_gradients()
self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
self.optimizer.step()
return self.optimizer.step()
def backward(self, loss: Tensor):
"""Start backward propagation given the loss value computed by a loss function

View File

@@ -1,9 +1,12 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
'MoeGradientHandler']
'MoeGradientHandler', 'SequenceParallelGradientHandler']

View File

@@ -0,0 +1,51 @@
#!/usr/bin/env python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
import colossalai
@GRADIENT_HANDLER.register_module
class SequenceParallelGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group.
A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
"""
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
# bucketize and all-reduce
buckets = {}
# Pack the buckets.
for param in self._model.parameters():
if param.requires_grad and param.grad is not None:
tp = param.data.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
dist.all_reduce(
coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
for buf, synced in zip(grads, _unflatten_dense_tensors(
coalesced, grads)):
buf.copy_(synced)
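The bucketing pattern above can be exercised on a single process; the following is a minimal sketch in which torch.distributed is replaced by the local division step and the world size is a made-up value:

# Sketch of the flatten / reduce / unflatten bucketing used in handle_gradient.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.ones(2) * 3.0]   # stand-ins for param.grad.data
world_size = 2                                  # hypothetical SEQUENCE_DP size

coalesced = _flatten_dense_tensors(grads)       # one contiguous buffer
coalesced /= world_size                         # pre-divide, as in handle_gradient
# real handler: dist.all_reduce(coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))

for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                           # copy reduced values back in place

print(grads)  # [tensor([0.5000, 0.5000, 0.5000]), tensor([1.5000, 1.5000])]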

View File

@@ -222,7 +222,6 @@ class PipelineSchedule(BaseSchedule):
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter)
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -

View File

@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import (accumulate_gradient, get_current_device,
sync_model_param_in_dp, is_using_ddp, is_using_pp)
sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
from colossalai.builder.builder import build_gradient_handler
from torch.optim.optimizer import Optimizer
@@ -187,7 +187,7 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
'''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
:param config: config file or config file path are both acceptable
@@ -270,12 +270,15 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
model.to(get_current_device())
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
if not moe_env.is_initialized() and not use_zero3:
sync_model_param_in_dp(model)
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
print(
"Warning: The parameters of models is not automatically synchronized.\n"
logger.warning(
"The parameters of models is not automatically synchronized.\n"
"Please make sure that all parameters are the same in data parallel group.",
flush=True)
ranks=[0])
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
@@ -339,11 +342,16 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_sequence():
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
elif is_using_ddp():
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
if verbose:

View File

@@ -6,6 +6,7 @@ import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.cuda.amp import custom_fwd, custom_bwd
import importlib
global colossal_layer_norm_cuda
@@ -15,6 +16,7 @@ colossal_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
@@ -29,6 +31,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
@@ -71,3 +74,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
def __repr__(self):
return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'

View File

@@ -6,7 +6,7 @@ JIT_OPTIONS_SET = False
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.
"""
# LSG: the latest pytorch and CUDA versions may not support
# the following jit settings
global JIT_OPTIONS_SET
if JIT_OPTIONS_SET == False:

View File

@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
class RingQK(torch.autograd.Function):
@@ -17,6 +18,7 @@
"""
@staticmethod
@custom_fwd
def forward(ctx,
sub_q,
sub_k,
@@ -54,6 +56,7 @@
return attention_score
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
sub_q, sub_k, = ctx.saved_tensors
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
@@ -64,6 +67,7 @@
grad_output.transpose(2, 1),
sub_q
)
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size
@@ -94,6 +98,7 @@
"""
@staticmethod
@custom_fwd
def forward(ctx,
attention_score,
sub_v,
@@ -131,6 +136,7 @@
return sub_attention_result
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

View File

@@ -2,15 +2,20 @@
# -*- encoding: utf-8 -*-
import math
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.context import seed
@LAYERS.register_module
@@ -31,136 +36,144 @@ class TransformerSelfAttentionRing(nn.Module):
def __init__(self,
hidden_size,
kv_channels,
num_attention_heads,
attention_dropout,
attention_mask_func,
layer_number,
apply_query_key_layer_scaling: bool = False,
convert_fp16_to_fp32_in_softmax: bool = False,
attn_mask_type=AttnMaskType.padding,
masked_softmax_fusion=True,
fp16=False,
bf16=False
):
super().__init__()
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_mask_func = attention_mask_func
self.layer_number = layer_number
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attn_mask_type = attn_mask_type
assert self.layer_number > 0
self.attention_dropout = attention_dropout
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = projection_size // num_attention_heads
if self.apply_query_key_layer_scaling:
self.convert_fp16_to_fp32_in_softmax = True
assert self.hidden_size % self.num_attention_heads == 0, \
'hidden size is not divisible by the number of attention heads'
self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# Fused query-key-value projection layer.
self.query_key_value = nn.Linear(
self.query_key_value = _Linear(
hidden_size,
3 * projection_size,
3 * self.hidden_size,
)
# coeff = None
self.coeff = None
self.norm_factor = math.sqrt(self.hidden_size)
# TODO: add apply_query_key_layer_scaling when we have the kernel module
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
if self.apply_query_key_layer_scaling:
self.coeff = layer_number
self.norm_factor *= self.coeff
# TODO: add fused scale mask softmax kernel when we have the kernel module
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
# self.fp16, self.bf16,
# self.attn_mask_type,
# masked_softmax_fusion,
# attention_mask_func,
# self.attention_softmax_in_fp32,
# coeff)
self.scale_mask_softmax = FusedScaleMaskSoftmax(
fp16, bf16,
self.attn_mask_type,
masked_softmax_fusion,
self.attention_mask_func,
self.convert_fp16_to_fp32_in_softmax,
self.coeff)
self.attention_dropout = nn.Dropout(attention_dropout)
# Output.
self.dense = nn.Linear(
projection_size,
hidden_size,
bias=True)
self.dense = _Linear(hidden_size,
hidden_size,
bias=True,
skip_bias_add=True)
def forward(self, hidden_states, attention_mask):
# hidden_states: [sq, b, h]
# hidden_states: [sub_seq_len, batch_size, hidden_size]
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
sub_seq_length, batch_size, hidden_size = hidden_states.size()
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
# Attention heads shape change:
# [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
mixed_x_layer = self.query_key_value(hidden_states)
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
# [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# split into query, key and value
last_dim = mixed_x_layer.dim() - 1
last_dim_value = mixed_x_layer.size()[-1]
last_dim_value = mixed_x_layer.size(-1)
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
'cannot be divided into query, key and value'
partition_size = last_dim_value // 3
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer, partition_size, dim=last_dim)
# ===================================
# Raw attention scores. [b, num_heads, s, s]
# ===================================
# [b, num_heads, sq, sk]
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0) * self.world_size)
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
key_layer = key_layer.view(key_layer.size(0),
output_size[0] * output_size[1], -1)
# [b, sq, sk]
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
attention_scores = RingQK.apply(
# [b * num_heads, sq, hn]
query_layer.transpose(0, 1).contiguous(),
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
batch_size,
self.num_attention_heads,
sub_seq_length
)
attention_scores /= self.norm_factor
# change view to [b, num_heads, sq, sk]
# change view to [batch_size, num_heads, sub_seq_len, seq_len]
attention_scores = attention_scores.view(*output_size)
attention_scores = attention_scores.unsqueeze(1)
attention_scores = attention_scores + attention_mask
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.squeeze(1)
# change shape to [batch_size, num_heads, sub_seq_len, seq_len]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
# with mpu.get_cuda_rng_tracker().fork():
# TODO: check if a rng tracker is needed
attention_probs = self.attention_dropout(attention_probs)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
# context layer shape: [b, num_heads, sq, hn]
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
#
# # change view [sk, b * num_heads, hn]
# change view [sub_seq_len, batch_size * num_heads, head_size]
value_layer = value_layer.contiguous().view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# # change view [b * num_heads, sq, sk]
# change view [batch_size * num_heads, sub_seq_len, seq_len]
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
attention_probs.size(2),
attention_probs.size(3))
# matmul: [b*num_heads, sq, hn]
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
context_layer = RingAV.apply(
attention_probs,
value_layer.transpose(0, 1).contiguous(),
@@ -170,19 +183,83 @@
sub_seq_length
)
# # change view [b, num_heads, sq, hn]
# change view [batch_size, num_heads, sub_seq_len, head_size]
context_layer = context_layer.view(*output_size)
# # [b, np, sq, hn] --> [sq, b, np, hn]
# [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# # [sq, b, np, hn] --> [sq, b, hp]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_attention_head * self.num_attention_heads,)
context_layer = context_layer.view(*new_context_layer_shape)
# context_layer = context_layer.transpose(1, 0).contiguous()
output = self.dense(context_layer)
bias = self.dense.bias
output, bias = self.dense(context_layer)
return output, bias
def __repr__(self):
return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
f'layer_number={self.layer_number}, hidden_size={self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'
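To make the shape comments in forward() easier to follow, here is a minimal single-process sketch of the same bookkeeping with toy sizes. It assumes a sequence-parallel world size of 1, so sub_seq_len equals seq_len and the ring operations reduce to ordinary batched matmuls; the fused mask-softmax and the dropout are omitted.

# Shape walk-through (hypothetical sizes, no distributed communication).
import torch
import torch.nn.functional as F

sub_seq_len, batch_size, hidden_size = 4, 2, 16
num_heads = 4
head_size = hidden_size // num_heads

hidden_states = torch.randn(sub_seq_len, batch_size, hidden_size)

# QKV projection: [s, b, h] -> [s, b, 3h] -> [s, b, num_heads, 3 * head_size]
mixed = torch.nn.Linear(hidden_size, 3 * hidden_size)(hidden_states)
mixed = mixed.view(sub_seq_len, batch_size, num_heads, 3 * head_size)
q, k, v = torch.split(mixed, head_size, dim=-1)

# [s, b, num_heads, head_size] -> [b * num_heads, s, head_size]
q = q.reshape(sub_seq_len, batch_size * num_heads, head_size).transpose(0, 1)
k = k.reshape(sub_seq_len, batch_size * num_heads, head_size).transpose(0, 1)
v = v.reshape(sub_seq_len, batch_size * num_heads, head_size).transpose(0, 1)

# attention scores: [b * num_heads, s, s], scaled by the same norm_factor as above
scores = torch.bmm(q, k.transpose(1, 2)) / (hidden_size ** 0.5)
probs = F.softmax(scores, dim=-1)

# context: [b * num_heads, s, head_size] -> [s, b, num_heads * head_size]
context = torch.bmm(probs, v)
context = context.view(batch_size, num_heads, sub_seq_len, head_size)
context = context.permute(2, 0, 1, 3).reshape(sub_seq_len, batch_size, hidden_size)
print(context.shape)  # torch.Size([4, 2, 16])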
class _Linear(nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
super(_Linear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
nn.init.xavier_normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output = F.linear(input_, self.weight, bias)
if self.skip_bias_add:
return output, self.bias
else:
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
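A short usage sketch of the skip_bias_add contract, using the _Linear class defined above with hypothetical tensor sizes: the bias is returned instead of being applied, so a caller can fuse the bias addition with a later elementwise operation.

# Sketch: the layer returns (output, bias) and the caller applies the bias itself.
import torch

linear = _Linear(input_size=16, output_size=16, bias=True, skip_bias_add=True)
x = torch.randn(4, 16)

out, bias = linear(x)   # bias is NOT added inside forward
out = out + bias        # the caller adds it, possibly fused with dropout/residual
print(out.shape)        # torch.Size([4, 16])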

View File

@@ -1,8 +1,8 @@
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp)
is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate,
print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
@@ -10,9 +10,9 @@ from .memory import report_memory_usage
from .timer import MultiTimer, Timer
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0',
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter',
'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank'

View File

@@ -47,16 +47,16 @@ def free_port():
continue
def sync_model_param_in_dp(model):
def sync_model_param(model, parallel_mode):
'''Make sure the parameters of the model are consistent across the given parallel group.
:param model: A PyTorch nn.Module whose parameters are to be synchronized
:param parallel_mode: The parallel mode whose process group the parameters are broadcast over
'''
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
ranks = gpc.get_ranks_in_group(parallel_mode)
dist.broadcast(
param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
param, src=ranks[0], group=gpc.get_group(parallel_mode))
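A hedged usage sketch of the new signature, mirroring how initialize() calls it above; it assumes the distributed context has already been set up by colossalai.launch and uses a placeholder model:

# Sketch: broadcast parameters from the first rank of the chosen group so that
# every process in that group starts from identical weights.
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.utils import sync_model_param, is_using_ddp, is_using_sequence

model = nn.Linear(8, 8).cuda()   # placeholder model

if is_using_sequence():
    sync_model_param(model, ParallelMode.SEQUENCE_DP)   # sequence parallelism
elif is_using_ddp():
    sync_model_param(model, ParallelMode.DATA)          # plain data parallelism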
def is_dp_rank_0():
@@ -79,6 +79,10 @@ def is_using_pp():
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
def is_using_sequence():
return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
@contextmanager
def conditional_context(context_manager, enable=True):
if enable:
@@ -240,16 +244,20 @@ def count_zeros_fp32(parameters):
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
# Sum across all model-parallel GPUs.
ops = []
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True))
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
for req in ops:
req.wait()
total_num_zeros = total_num_zeros.item()

View File

@@ -40,8 +40,6 @@ def report_memory_usage(message, logger=None, report_cpu=False):
:type report_cpu: bool
:raises EnvironmentError: raise error if no distributed environment has been initialized
'''
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
from typing import Iterator, Tuple
from .cuda import synchronize
@@ -8,6 +9,7 @@ class Timer:
'''
A timer object which helps to log execution times and provides different tools to assess them.
'''
def __init__(self):
self._started = False
self._start_time = time.time()
@@ -129,6 +131,6 @@
def set_status(self, mode: bool):
self._on = mode
def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, Timer]]:
for name, timer in self._timers.items():
yield name, timer

View File

@@ -1,48 +1,150 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import colossalai
import colossalai.nn as col_nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from checks_seq.check_layer_seq import *
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.utils import free_port
CONFIG = dict(
parallel=dict(
pipeline=1,
tensor=dict(mode='sequence', size=4)
tensor=dict(size=4, mode='sequence')
)
)
def check_layer():
check_selfattention()
def check_ring_qk(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 32
attention_head_size = 32
sub_seq_length = seq_length // world_size
# create master tensors
q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
# set autograd attributes
q.requires_grad = True
k.requires_grad = True
q.retain_grad()
k.retain_grad()
sub_q.requires_grad = True
sub_k.requires_grad = True
sub_q.retain_grad()
sub_k.retain_grad()
# compute master attention scores
a = torch.matmul(q, k.transpose(2, 1))
# compute distributed attention scores
ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
# run master backward
a.retain_grad()
a.mean().backward()
# run distributed backward
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
torch.autograd.backward(sub_a, partial_master_a_grad)
# check master and distributed grads
partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
'the gradient of the query does not match the master gradient'
def run_check_sequence(rank, world_size, port):
# init dist
launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
logger = get_dist_logger()
logger.info('Distributed environment is initialized.', ranks=[0])
def check_ring_av(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 16
attention_head_size = 32
sub_seq_length = seq_length // world_size
# check layers
check_layer()
# create master tensors
a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
# set autograd attributes
a.requires_grad = True
v.requires_grad = True
a.retain_grad()
v.retain_grad()
sub_a.requires_grad = True
sub_v.requires_grad = True
sub_a.retain_grad()
sub_v.retain_grad()
# compute master attention scores
out = torch.matmul(a, v)
# compute distributed attention scores
ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
# check master and distributed output
sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
# run master backward
out.retain_grad()
out.mean().backward()
# run distributed backward
partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
torch.autograd.backward(sub_out, partial_master_out_grad)
# check master and distributed grads
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
'the gradient of the attention score does not match the master gradient'
def run_test(rank, world_size):
colossalai.launch(
rank=rank,
world_size=world_size,
config=CONFIG,
host='localhost',
port=29500
)
check_ring_qk(rank, world_size)
check_ring_av(rank, world_size)
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_sequence():
world_size = 4
run_func = partial(run_check_sequence, world_size=world_size, port=free_port())
run_func = partial(run_test, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)