Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO

pull/5850/head
YeAnbang 2024-06-28 02:50:14 +00:00
commit e7527762a1
34 changed files with 952 additions and 514 deletions

View File

@ -9,6 +9,7 @@
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> |
<a href="https://hpc-ai.com/blog"> Blog </a></h3>
[![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers)
@ -132,6 +133,8 @@ distributed training and inference in a few lines.
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
<div align="center">
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
@ -143,6 +146,9 @@ distributed training and inference in a few lines.
### Colossal-LLaMA-2
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
- 7B: One half-day of training using a few hundred dollars yields results similar to mainstream large models; an open-source and commercial-free domain-specific LLM solution.
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
@ -275,6 +281,8 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
- 70 billion parameter LLaMA3 model training accelerated by 18%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
### LLaMA2
<p align="center">
@ -385,6 +393,8 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
- Large AI model inference speed doubled, compared to the offline inference performance of vLLM in some cases.
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference)
[[blog]](https://hpc-ai.com/blog/colossal-inference)
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
### Grok-1
<p id="Grok-1" align="center">

View File

@ -2,6 +2,12 @@
<h1>
Colossal-LLaMA
</h1>
<h3>
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
<a href="https://cloud.luchentech.com/doc/docs/image/llama"> LLaMA3 Image </a>
</h3>
</div>
## Table of Contents

View File

@ -2,6 +2,12 @@
<h1>
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossaleval.jpg?raw=true" width=800/>
</h1>
<h3>
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
<a href="https://cloud.luchentech.com/doc/docs/image/colossal-eval"> Colossal-Eval Image </a>
</h3>
</div>
## Table of Contents

View File

@ -2,6 +2,15 @@
This directory contains the applications that are powered by Colossal-AI.
<div align="center">
<h3>
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
<a href="https://cloud.luchentech.com/doc/docs/intro"> Playground Document </a>
</h3>
</div>
The list of applications includes:
- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models

View File

@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase):
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
if placement_policy == "auto" and enable_async_reduce:
if enable_async_reduce and not pin_memory:
logging.warning(
f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
)
pin_memory = True
self.gemini_config = dict(
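Below is a minimal, hedged usage sketch of the adjusted check: when enable_async_reduce is on and pin_memory was left unset, the plugin now warns and forces pin_memory=True regardless of placement_policy. Argument names follow the checks above; other arguments keep their defaults.

from colossalai.booster.plugin import GeminiPlugin

# Sketch only: passing pin_memory=True explicitly avoids the warning above.
plugin = GeminiPlugin(
    precision="bf16",
    placement_policy="static",
    enable_async_reduce=True,  # async gradient reduction
    pin_memory=True,           # required for best performance with async reduce
)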

View File

@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
"""
def __init__(
@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
) -> None:
super().__init__()
assert (
@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
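A hedged usage sketch of the new overlap_p2p flag follows; the pp_style value "interleaved" and the other constructor arguments (tp_size, pp_size, num_model_chunks) are assumed here with their usual meaning, and overlap_p2p is simply forwarded to the interleaved schedule as shown above.

from colossalai.booster.plugin import HybridParallelPlugin

# Sketch only: interleaved pipeline schedule with p2p communication
# overlapped with computation.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="interleaved",
    num_model_chunks=2,
    num_microbatches=8,
    overlap_p2p=True,
)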

View File

@ -134,7 +134,7 @@ class ProcessGroupMesh:
"""
assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)
return int(np.ravel_multi_index(coord, shape, mode))
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
@ -182,7 +182,7 @@ class ProcessGroupMesh:
axis = [
axis,
]
assert isinstance(indices_at_axis[0], int)
assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
indices_at_axis = [
indices_at_axis,
]
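A small illustration of why the return value is now wrapped in int(): np.ravel_multi_index returns a NumPy integer scalar, while callers that treat the result as a plain rank index expect a Python int.

import numpy as np

# Ravel a mesh coordinate into a flat rank index.
flat = np.ravel_multi_index((1, 2), (2, 4), mode="raise")
print(type(flat))       # <class 'numpy.intp'> (a NumPy scalar)
print(type(int(flat)))  # <class 'int'>
print(int(flat))        # 6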

View File

@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@ -122,16 +123,24 @@ class InferenceEngine:
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
# NOTE(lry89757) Currently we load the model using transformers-api,
# but we will use lazy tensor and checkpoint io to accelerate
# the model load process in the future.
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
if arch is "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
@ -189,14 +198,13 @@ class InferenceEngine:
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
# NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
# if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
# from colossalai.inference.core.plugin import InferCheckpoint_io
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
# cpt_io = InferCheckpoint_io()
# if_has_index_file, model_index_file = has_index_file(model_or_path)
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
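A minimal sketch of the lazy-init loading flow introduced here, using the names imported in this diff; the checkpoint path is a placeholder. Under LazyInitContext the model structure is built without materializing real weights, and the sharded checkpoint is loaded afterwards through InferCheckpoint_io via the *.index.json located by has_index_file.

from colossalai.lazy import LazyInitContext
from colossalai.inference.utils import has_index_file
from transformers import AutoModelForCausalLM

# Sketch only: "path/to/checkpoint" stands for a local sharded checkpoint directory.
with LazyInitContext(default_device="cuda"):
    model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)

# Locate the sharded-checkpoint index that InferCheckpoint_io later consumes.
if_has_index_file, model_index_file = has_index_file("path/to/checkpoint")
assert if_has_index_file, "the model path is invalid"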

View File

@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):
try:
if isinstance(model_or_path, str):
self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
self.model_config = AutoConfig.from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
elif isinstance(model_or_path, nn.Module):
self.logger.error(
f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"

View File

@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
model_policy_map,
)
from colossalai.inference.sampler import search_tokens
from colossalai.inference.utils import get_model_size
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
model_policy (Policy): the policy to replace the model
"""
pretrained_path = None
if isinstance(model_or_path, str):
# is_local = os.path.isdir(model_or_path)
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
# NOTE(lry89757) Currently we load the model using transformers-api,
# but we will use lazy tensor and checkpoint io to accelerate
# the model load process in the future.
model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
# if is_local:
# model = _SUPPORTED_MODELS[arch](hf_config)
# else:
# # load the real checkpoint
# model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
if arch is "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _SUPPORTED_MODELS[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
except Exception as e:
logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
# NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
# if isinstance(model_or_path, str) and is_local:
# from colossalai.inference.core.plugin import InferCheckpoint_io
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
# cpt_io = InferCheckpoint_io()
# if_has_index_file, model_index_file = has_index_file(model_or_path)
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory

View File

@ -1,8 +1,10 @@
from typing import List, Union
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule
@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
LazyInitContext.materialize(module)
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(
module.weight
) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
) # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
# So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
)
lmhead_1d = BaichuanLMHeadLinear1D_Col(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return lmhead_1d
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
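A small illustration of what the _load_from_state_dict hook above applies to the lm_head weight: nn.functional.normalize with default arguments rescales each row of a 2D tensor to unit L2 norm.

import torch
import torch.nn as nn

w = torch.randn(4, 8)
w_norm = nn.functional.normalize(w)  # defaults: p=2, dim=1 -> row-wise normalization
print(w_norm.norm(dim=1))            # all values are ~1.0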

View File

@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
"""
ParallelModule.__init__(self)
self.o_proj = attn_oproj
self.config = config
self.num_heads = num_heads
@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
self.head_dim = self.hidden_size // self.num_heads
self.process_group = process_group
self.W_pack = W_pack
self.o_proj = attn_oproj
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)

View File

@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
self.gate_up_weight = nn.Parameter(
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
)
self.gate_up_dict = {
"gate_proj.weight": None,
"up_proj.weight": None,
} # used and delattr in load/shard of gate/up weight
self.down_proj = mlp_dproj
self.process_group = process_group
@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
):
# NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
if hasattr(self, "gate_up_dict"):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
key = "gate_up_weight"
k1 = "gate_proj.weight"
k2 = "up_proj.weight"
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
for weight_name in self.gate_up_dict:
prefix_weight_name = prefix + weight_name
if prefix_weight_name in state_dict.keys():
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
self.gate_up_dict[weight_name] = w.T
gate_w = state_dict[prefix + k1]
up_w = state_dict[prefix + k2]
if None not in self.gate_up_dict.values():
# we've got all the weights of gate/up
gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
input_param = nn.Parameter(
gate_up_w
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
key = "gate_up_weight"
param = local_state.get(key, None)
input_param = nn.Parameter(
gate_up_w
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
del self.gate_up_dict
strict = False # to avoid unexpected_keys
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
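A standalone sketch of the collect-then-stack idea behind gate_up_dict (the names below are illustrative, not the module's API): each _load_from_state_dict call may carry only some of the projection weights when the checkpoint is sharded, so weights are buffered until every expected key has arrived and only then stacked into the fused parameter.

import torch

# Illustrative only: buffer partial weights across multiple load calls,
# then fuse them once every expected key has been seen.
pending = {"gate_proj.weight": None, "up_proj.weight": None}

def try_fuse(state_dict, prefix=""):
    for name in pending:
        key = prefix + name
        if key in state_dict:
            pending[name] = state_dict[key].T
    if None not in pending.values():
        return torch.stack(list(pending.values()), dim=0)  # fused gate/up weight
    return None  # still waiting for the other shard

# First shard only contains the gate projection.
print(try_fuse({"gate_proj.weight": torch.randn(16, 8)}))       # None
# Second shard delivers the up projection -> fused tensor of shape (2, 8, 16).
print(try_fuse({"up_proj.weight": torch.randn(16, 8)}).shape)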
@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.helper_layout = (
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
self.qkv_dict = {
"q_proj.weight": None,
"k_proj.weight": None,
"v_proj.weight": None,
} # used and delattr in load/shard of qkv weight
else:
self.helper_layout = (
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if self.num_heads == self.num_key_value_heads:
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
key = "qkv_weight"
k1 = "q_proj.weight"
k2 = "k_proj.weight"
k3 = "v_proj.weight"
q_w = state_dict[prefix + k1]
k_w = state_dict[prefix + k2]
v_w = state_dict[prefix + k3]
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
# NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
# Here we need the weights of q, k, v to stack them into one qkv weight.
# Unfortunately, it is highly likely that the weights of q, k, v are not all in the same sharded checkpoint file (like meta-llama/llama3-70B),
# so we only stack them once we have actually collected all three weights.
for weight_name in self.qkv_dict:
prefix_weight_name = prefix + weight_name
if prefix_weight_name in state_dict.keys():
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
self.qkv_dict[weight_name] = w.T
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
if None not in self.qkv_dict.values():
# we've got all the weights of q, k, v
qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
param = local_state[key]
param = local_state[key]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
strict = False # to avoid unexpected_keys
del self.qkv_dict
else:
def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
if prefix + origin_weight_name in state_dict.keys():
attn_qproj_w = state_dict[prefix + origin_weight_name]
w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
input_param = nn.Parameter(w.T)
param = local_state[local_weight_name]
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
key = local_weight_name
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
if prefix + "q_proj.weight" in state_dict.keys():
_load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
if prefix + "k_proj.weight" in state_dict.keys():
_load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
if prefix + "v_proj.weight" in state_dict.keys():
_load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
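A short, hedged illustration of why q/k/v may arrive in different load calls (the file name below is a placeholder): Hugging Face sharded checkpoints ship a *.index.json whose weight_map assigns each parameter name to the shard file containing it, and for large models such as meta-llama/llama3-70B the q/k/v projections of one layer can live in different shards.

import json

# Placeholder path: inspect which shard holds each projection weight.
with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

for name in ("q_proj", "k_proj", "v_proj"):
    key = f"model.layers.0.self_attn.{name}.weight"
    print(key, "->", weight_map.get(key))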

View File

@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
overlap_p2p: bool = True,
send_first: bool = True,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
if send_dst is not None and send_tensor_list is not None:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None and buffer_recv is not None:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
is_send = send_dst is not None and send_tensor_list is not None
is_recv = recv_src is not None and buffer_recv is not None
if send_first:
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
assert recv_group is not None
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if is_send:
assert send_group is not None
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
# torch.cuda.synchronize()
return buffer_recv
if not overlap_p2p:
for req in reqs:
req.wait()
return buffer_recv, []
else:
return buffer_recv, reqs
return None, []
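A minimal sketch of the batched point-to-point pattern used above, written directly with torch.distributed.P2POp (the helper _filling_ops_queue builds the same kind of op list). It assumes an initialized process group and valid peer ranks, and simply returns the requests so the caller can overlap the wait with computation.

import torch
import torch.distributed as dist

def batched_send_recv(
    send_tensor: torch.Tensor,
    recv_buffer: torch.Tensor,
    send_dst: int,
    recv_src: int,
    group: dist.ProcessGroup,
    send_first: bool = True,
) -> list:
    # Build the op list in a rank-consistent order, launch everything in one
    # batch, and hand the request handles back to the caller.
    send_op = dist.P2POp(dist.isend, send_tensor, send_dst, group)
    recv_op = dist.P2POp(dist.irecv, recv_buffer, recv_src, group)
    ops = [send_op, recv_op] if send_first else [recv_op, send_op]
    reqs = dist.batch_isend_irecv(ops)
    return reqs  # caller calls req.wait() later, or defers it to overlap with compute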
def _send_recv_serialization_object(
@ -260,10 +270,11 @@ def _send_recv_serialization_object(
recv_group: Optional[ProcessGroup],
current_device: Any,
is_nccl_backend: bool,
send_first: bool = True,
) -> Optional[P2PMetadata]:
ops = []
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
@ -274,43 +285,54 @@ def _send_recv_serialization_object(
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if recv_src is not None:
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if send_object_size_tensor is not None:
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
req.wait() # This blocks the compute stream in torch
ops = []
if send_dst is not None and send_object_tensor is not None:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
is_send = send_dst is not None and send_object_tensor is not None
is_recv = recv_src is not None and recv_object_size_tensor is not None
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None:
if is_recv:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if send_first:
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
else:
if is_recv:
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if is_send:
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
if recv_object_tensor is not None and recv_object_size_tensor is not None:
recv_object_tensor = recv_object_tensor.type(torch.uint8)
if recv_object_tensor.device != torch.device("cpu"):
@ -328,11 +350,12 @@ def _communicate(
object: Any,
send_dst: Optional[int],
recv_src: Optional[int],
overlap_p2p: bool,
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
send_first: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
@ -341,6 +364,7 @@ def _communicate(
object (Any): object needed to be sent
send_dst (int): rank of the destination
recv_src (int): rank of the source
overlap_p2p (bool): whether to overlap p2p communication with computation
send_group (ProcessGroup, optional): process group of sender
recv_group (ProcessGroup, optional): process group of receiver
send_metadata (bool, optional): whether to send metadata
@ -358,32 +382,10 @@ def _communicate(
# NOTE: if object contains non-tensor objects, we have to send metadata
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
else:
send_metadata = False
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
else:
recv_data = _communicate(
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data
# NOTE: only the following 5 cases are valid:
# 1. send() [needs extra metadata] and no recv()
# 2. recv() [needs extra metadata] and no send()
# 3. neither send() nor recv() need extra metadata
assert not (send_dst is not None and send_metadata) or recv_src is None
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
current_send_device, is_send_nccl_backend = _check_device(send_group)
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
@ -402,14 +404,25 @@ def _communicate(
recv_group=recv_group if metadata_recv is None else None,
current_device=current_device,
is_nccl_backend=is_nccl_backend,
send_first=send_first if send_first is not None else True,
)
assert metadata_recv is None or _metadata_recv is None
assert (
metadata_recv is None or _metadata_recv is None
), "You shouldn't receive metadata when using the cached metadata"
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
# Send and receive data
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
recv_tensor_objs = _batch_send_recv_tensor(
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
tensor_objs,
recv_tensor_metadata,
send_dst,
recv_src,
send_group,
recv_group,
current_device,
overlap_p2p=overlap_p2p,
send_first=send_first if send_first is not None else True,
)
if metadata_recv is not None:
@ -424,33 +437,9 @@ def _communicate(
for idx in non_tensor_obj_idx:
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
return recv_object, wait_handles
return recv_object
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank
Args:
object (Any): object needed to be sent
dst (int): rank of the destination
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src
Args:
src (int): source rank of data. local rank will receive data from src rank.
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
return None, wait_handles
def _p2p_comm(
@ -532,10 +521,13 @@ def _p2p_comm(
class PipelineP2PCommunication:
def __init__(self, stage_manager: PipelineStageManager) -> None:
def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
self.stage_manager = stage_manager
self.overlap_p2p = overlap_p2p
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_forward(
self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@ -543,95 +535,186 @@ class PipelineP2PCommunication:
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
input_tensor, wait_handles = _communicate(
object=None,
recv_src=prev_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return input_tensor
return input_tensor, wait_handles
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
def recv_backward(
self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The input gradient tensor or gradient tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
output_tensor_grad, wait_handles = _communicate(
object=None,
recv_src=next_rank,
send_dst=None,
recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
overlap_p2p=self.overlap_p2p,
)
return output_tensor_grad
return output_tensor_grad, wait_handles
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the input tensor to the next stage in pipeline.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
output_object,
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
recv_src=None,
send_dst=next_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(
_, handles = _communicate(
input_object,
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
recv_src=None,
send_dst=prev_rank,
send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
overlap_p2p=self.overlap_p2p,
)
return handles
def send_forward_recv_forward(
self,
output_object: Any,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
) -> Tuple[Any, List]:
"""Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
Args:
output_object (Any): Object to be sent.
is_send (bool): Whether to send the input tensor to the next pipeline stage.
is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank() if is_send else None
prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
return _communicate(
output_object,
send_dst=next_rank,
recv_src=prev_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_backward_recv_backward(
self,
input_object: Any,
is_send: bool,
is_recv: bool,
send_first: bool,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
Args:
input_object (Any): Object to be sent.
is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
send_first (bool): Whether to send before receive.
send_metadata (bool, optional): Whether to send metadata.
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
prev_rank = self.stage_manager.get_prev_rank() if is_send else None
next_rank = self.stage_manager.get_next_rank() if is_recv else None
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
send_dst=prev_rank,
recv_src=next_rank,
send_group=group if is_send else None,
recv_group=group if is_recv else None,
send_metadata=send_metadata if is_send else False,
metadata_recv=metadata_recv if is_recv else None,
send_first=send_first,
overlap_p2p=self.overlap_p2p,
)
def send_forward_recv_backward(
self,
input_object: Any,
next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
Args:
input_object (Any): Object to be sent.
next_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
next_rank = self.stage_manager.get_next_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
next_rank,
@ -640,28 +723,28 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def send_backward_recv_forward(
self,
input_object: Any,
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
send_first: Optional[bool] = None,
) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
Returns:
Any: The input tensor or input tensor list.
List: List of handles for the communication requests, if overlap is enabled.
"""
prev_rank = self.stage_manager.get_prev_rank()
group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
prev_rank,
@ -670,7 +753,8 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
overlap_p2p=False,
)
def p2p_communicate(
@ -679,7 +763,7 @@ class PipelineP2PCommunication:
recv_pre: bool,
next_rank: Optional[int] = None,
comm_dtype: torch.dtype = torch.float16,
) -> None:
) -> Any:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@ -689,12 +773,11 @@ class PipelineP2PCommunication:
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(
output_object,
recv_pre,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
self.stage_manager.get_p2p_process_group(),
comm_dtype,
)
return recv_tensor

View File

@ -1,8 +1,9 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
class InterleavedSchedule(PipelineSchedule):
def __init__(
self,
@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
self.overlap_p2p = overlap_p2p
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert microbatch_id < self.num_microbatch * self.num_model_chunks
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor, wait_handles
return None, []
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
return output_tensor_grad
return None, []
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
return []
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
output_tensor: Any,
next_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_last_stage()
if send_data and recv_data:
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad
# send only or recv only
self.send_forward(model_chunk_id_send, output_tensor)
return self.recv_backward(model_chunk_id_recv)
def send_backward_recv_forward(
self,
model_chunk_id_send: int,
model_chunk_id_recv: int,
input_tensor_grad: Any,
prev_rank: Optional[int] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
send_data = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
recv_data = not self.stage_manager.is_first_stage()
if send_data and recv_data:
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
# send only or recv only
self.send_backward(model_chunk_id_send, input_tensor_grad)
return self.recv_forward(model_chunk_id_recv)
return send_handles
return []
def send_forward_recv_forward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
):
if send_prior:
self.send_forward(model_chunk_id_send, output_tensor)
input_tensor = self.recv_forward(model_chunk_id_recv)
else:
input_tensor = self.recv_forward(model_chunk_id_recv)
self.send_forward(model_chunk_id_send, output_tensor)
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
is_send = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_first_stage()
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
output_tensor,
is_send,
is_recv,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.tensor_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_tensor_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
return input_tensor
return input_tensor, wait_handles
def send_backward_recv_backward(
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
):
if send_prior:
self.send_backward(model_chunk_id_send, input_tensor_grad)
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
else:
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
self.send_backward(model_chunk_id_send, input_tensor_grad)
return output_tensor_grad
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
) -> Tuple[Any, List]:
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
is_send = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_last_stage()
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
input_tensor_grad,
is_send,
is_recv,
send_metadata=self.send_grad_metadata,
metadata_recv=self.grad_metadata_recv,
send_first=send_first,
)
# Cache metadata
self.send_grad_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
return output_tensor_grad, wait_handles
def forward_step(
self,
@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
fwd_wait_handles = []
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
for i in range(self.num_microbatch * self.num_model_chunks):
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
last_batch = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait until current input is received
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not last_iteration:
input_obj = self.send_forward_recv_forward(
if not last_batch:
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
else:
self.send_forward(model_chunk_id, output_obj)
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
# Forward + until 1st backward
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
# Steps needed to reach the last chunk
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
bwd_wait_handles = []
# Get the 1st input batch
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
last_batch = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
# Wait for input
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
if last_iteration and num_microbatch_remaining == 0:
self.send_forward(model_chunk_id, output_obj)
if last_batch and num_microbatch_remaining == 0:
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
else:
input_obj = self.send_forward_recv_forward(
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
fwd_batch_id = i + num_warmup_microbatch
last_batch = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
# Wait for input.
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
# Pop input_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
# Helper functions
def send_forward_recv_forward():
if last_batch:
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
wait_handles = self.send_forward(model_chunk_id, output_obj)
return None, wait_handles
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_obj, wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0
and i > 0, # Receive from warmup stage first in the first batch
)
return output_obj_grad
return input_obj, wait_handles
def send_backward_recv_forward():
if last_iteration:
def send_backward_recv_backward():
no_cooldown = num_microbatch == num_microbatch_remaining
if last_batch and no_cooldown:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
return None, wait_handles
else:
input_obj = self.send_backward_recv_forward(
output_obj_grad, wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
send_first=self.stage_manager.stage % 2 == 0,
)
return input_obj
return output_obj_grad, wait_handles
if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
# however in practice this works fine, and Megatron does this too
# (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
# if deadlock, call _wait_p2p(fwd_wait_handles) here
output_obj_grad, bwd_wait_handles = send_backward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
last_iteration = i == num_microbatch - 1
last_batch = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
# output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:
output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
_ = self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
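
The pattern the overlapped helpers above rely on is easiest to see in isolation: every communication call returns the received object together with a list of wait handles, and `_wait_p2p` is only called right before the data is consumed, so the transfer for the next microbatch overlaps with the current forward step. Below is a minimal, self-contained sketch of that contract; the `FakeWork` handle and the `forward_step` body are stand-ins, not the real schedule or p2p layer.

```python
from typing import List, Tuple

import torch


class FakeWork:
    """Stand-in for an async torch.distributed work handle."""

    def wait(self) -> None:  # a real handle blocks until the transfer completes
        pass


def send_forward_recv_forward(output_tensor: torch.Tensor) -> Tuple[torch.Tensor, List[FakeWork]]:
    # The real helper posts isend/irecv and returns immediately with wait handles.
    received = torch.empty_like(output_tensor)
    return received, [FakeWork()]


def _wait_p2p(wait_handles: List[FakeWork]) -> None:
    if wait_handles is not None:
        for work in wait_handles:
            work.wait()


def forward_step(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for the model chunk forward


# Steady-state shape of the loop above: communication for microbatch i+1 is in flight
# while microbatch i is being computed; we only block right before consuming the data.
input_obj, fwd_wait_handles = torch.ones(4), []
for _ in range(3):
    _wait_p2p(fwd_wait_handles)
    output_obj = forward_step(input_obj)
    input_obj, fwd_wait_handles = send_forward_recv_forward(output_obj)
```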

View File

@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
def send_forward_recv_backward(
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
send_first = None
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
return output_tensor_grad
def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
send_first = None # must not fallback
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
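
The repeated `send_first=self.stage_manager.stage % 2 == 0` arguments encode a simple deadlock-avoidance rule: neighbouring stages must not both block on a send at the same time, so even stages send before receiving while odd stages receive first. A hedged sketch of just that ordering decision follows; the callables stand in for the real p2p operations.

```python
def exchange(stage: int, send, recv):
    """Order a paired send/recv so that adjacent stages never both send first."""
    send_first = stage % 2 == 0
    if send_first:
        send()
        return recv()
    received = recv()
    send()
    return received


log = []
result = exchange(
    stage=2,  # even stage: send first
    send=lambda: log.append("send activation to stage 3"),
    recv=lambda: (log.append("recv gradient from stage 3"), "grad")[-1],
)
print(log, result)  # ['send activation to stage 3', 'recv gradient from stage 3'] grad
```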

View File

@ -35,7 +35,7 @@ class PipelineStageManager:
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
@ -48,30 +48,14 @@ class PipelineStageManager:
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
# init p2p process groups
stages = list(range(self.num_stages))
for prev, cur in zip(stages[:-1], stages[1:]):
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
if self.stage in [prev, cur]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis)
def get_stage_index(
self,
@ -184,19 +168,12 @@ class PipelineStageManager:
"""
return self.next_rank
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
def get_p2p_process_group(self) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
Args:
first_rank (int): The first rank.
second_rank (int): The second rank.
Returns:
ProcessGroup: The p2p process group shared by all pipeline stages.
"""
if first_rank > second_rank:
first_rank, second_rank = second_rank, first_rank
return self.p2p_groups[(first_rank, second_rank)]
return self.p2p_group
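
Since a single process group now covers the whole pipeline axis, callers no longer pass a rank pair. A tiny stub illustrating the simplified lookup (this is not the real `PipelineStageManager`, which builds `self.p2p_group` from the process-group mesh in `__init__`):

```python
class _StageManagerStub:
    """Illustrative stand-in; not the real PipelineStageManager."""

    def __init__(self, p2p_group: str) -> None:
        # built once along the pipeline axis instead of one group per adjacent rank pair
        self.p2p_group = p2p_group

    def get_p2p_process_group(self):
        return self.p2p_group  # no (first_rank, second_rank) arguments anymore


manager = _StageManagerStub(p2p_group="pp_axis_group")
assert manager.get_p2p_process_group() == "pp_axis_group"
```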
def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
"""Get the process group of the given stages.

View File

@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features

View File

@ -8,8 +8,15 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
TokenClassifierOutput,
)
from transformers.models.t5.modeling_t5 import (
T5EncoderModel,
T5ForConditionalGeneration,
T5ForTokenClassification,
T5Model,
T5Stack,
)
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -582,6 +589,71 @@ class T5PipelineForwards:
return outputs
@staticmethod
def t5_for_token_classification_forward(
self: T5ForTokenClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward.
Please refer to the original transformers code for more details.
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = T5PipelineForwards.t5_stack_forward(
self.transformer.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
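
On the last pipeline stage, the new forward only applies the classification head to the encoder output. A standalone sketch of that head with illustrative sizes (the dropout rate, hidden size, and label count below are made up, not taken from the model config):

```python
import torch
import torch.nn as nn

batch, seq_len, d_model, num_labels = 2, 16, 128, 2
sequence_output = torch.randn(batch, seq_len, d_model)   # encoder output reaching the last stage
labels = torch.randint(0, num_labels, (batch, seq_len))  # one class id per token

dropout = nn.Dropout(p=0.1)
classifier = nn.Linear(d_model, num_labels)

logits = classifier(dropout(sequence_output))             # (batch, seq_len, num_labels)
loss = nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(logits.shape, loss.item())
```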
def get_t5_flash_attention_forward():
from transformers.models.t5.modeling_t5 import T5Attention

View File

@ -68,6 +68,9 @@ _POLICY_LIST = {
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
"transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
file_name="t5", class_name="T5ForTokenClassificationPolicy"
),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(

View File

@ -31,7 +31,13 @@ from ..modeling.t5 import (
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
__all__ = [
"distribute_t5_layers",
"T5ModelPolicy",
"T5ForConditionalGenerationPolicy",
"T5EncoderPolicy",
"T5ForTokenClassificationPolicy",
]
class T5BasePolicy(Policy):
@ -312,9 +318,13 @@ class T5BasePolicy(Policy):
assert self.pipeline_stage_manager is not None
stage_manager = self.pipeline_stage_manager
model = self.model
encoder = self.model.encoder
decoder = getattr(self.model, "decoder", None)
if self.model.__class__.__name__ == "T5ForTokenClassification":
model = self.model.transformer
else:
model = self.model
encoder = model.encoder
decoder = getattr(model, "decoder", None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
@ -353,7 +363,11 @@ class T5BasePolicy(Policy):
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
encoder = self.model.encoder
if self.model.__class__.__name__ == "T5ForTokenClassification":
encoder = self.model.transformer.encoder
else:
encoder = self.model.encoder
decoder = getattr(self.model, "decoder", None)
num_encoder_layers = len(encoder.block)
@ -542,3 +556,46 @@ class T5EncoderPolicy(T5BasePolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
class T5ForTokenClassificationPolicy(T5EncoderPolicy):
def module_policy(self):
from transformers.models.t5.modeling_t5 import T5ForTokenClassification
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
T5ForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=T5ForTokenClassification,
new_forward=T5PipelineForwards.t5_for_token_classification_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]:
"""
get pipeline layers for current stage
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for token classification model
return []

View File

@ -403,9 +403,9 @@ class Chunk:
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
)
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
self.grad_reduce_work = dist.reduce_scatter(
self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
assert self.cuda_global_chunk.is_contiguous()
self.grad_reduce_work = dist.reduce_scatter_tensor(
self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
)
if self.extra_dp_group is not None:
@ -520,8 +520,10 @@ class Chunk:
assert self.cuda_shard is not None
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
assert self.cuda_global_chunk.is_contiguous()
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
self.cuda_shard = None
self.is_gathered = True
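
The switch from the list-based collectives to `reduce_scatter_tensor` / `all_gather_into_tensor` avoids building Python lists of chunk views, but it requires the flat global chunk to be contiguous, hence the added asserts. A hedged sketch of the equivalent calls (it assumes an already-initialized process group, e.g. under torchrun; buffer devices and sizes are illustrative):

```python
import torch
import torch.distributed as dist


def reduce_chunk_shard(cuda_global_chunk: torch.Tensor, pg=None) -> torch.Tensor:
    world_size = dist.get_world_size(pg)
    assert cuda_global_chunk.is_contiguous()  # required by the *_tensor collectives
    shard = torch.empty(
        cuda_global_chunk.numel() // world_size, dtype=cuda_global_chunk.dtype, device=cuda_global_chunk.device
    )
    # old: dist.reduce_scatter(shard, list(torch.chunk(cuda_global_chunk, world_size)), group=pg)
    dist.reduce_scatter_tensor(shard, cuda_global_chunk, group=pg)
    return shard


def gather_chunk(cuda_global_chunk: torch.Tensor, cuda_shard: torch.Tensor, pg=None) -> None:
    assert cuda_global_chunk.is_contiguous()
    # old: dist.all_gather(list(torch.chunk(cuda_global_chunk, dist.get_world_size(pg))), cuda_shard, pg)
    dist.all_gather_into_tensor(cuda_global_chunk, cuda_shard, group=pg)
```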

View File

@ -133,12 +133,12 @@ class ChunkManager:
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move: bool = False) -> None:
"""Move the shard of the chunk to the target device."""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memory_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
chunk.shard_move(device, force_copy, non_blocking=async_move)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:

View File

@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
p: nn.Parameter,
async_reduce_stream: Optional[torch.cuda.Stream] = None,
):
async_reduce_scatter = async_reduce_stream is not None
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
async_reduce_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(async_reduce_stream):
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
if reduced:
grad_chunk.wait_async_reduce()
if not chunk_manager.reuse_fp16_chunk:
@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
chunk_manager.move_chunk(
grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
chunk_manager.move_chunk(
chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
)
return empty_grad
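
The new `async_move` flag simply forwards a non-blocking flag to `shard_move`, so the grad shard can migrate (e.g. to CPU) on the same side stream that runs the asynchronous reduce-scatter. Below is a minimal stream sketch of that overlap, assuming a CUDA device and pinned host memory; the in-place multiply stands in for the reduction.

```python
import torch

if torch.cuda.is_available():
    grad_chunk = torch.randn(1 << 20, device="cuda")
    host_shard = torch.empty(grad_chunk.numel(), device="cpu", pin_memory=True)

    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())  # respect pending writes to grad_chunk
    with torch.cuda.stream(side_stream):
        grad_chunk.mul_(0.5)                               # stand-in for the async reduce
        host_shard.copy_(grad_chunk, non_blocking=True)    # async_move=True -> non-blocking shard move

    torch.cuda.current_stream().wait_stream(side_stream)   # later GPU work waits for the side stream
    torch.cuda.current_stream().synchronize()              # ensure the host copy has landed before reading it
```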
def zero_grad(self, set_to_none: bool = False) -> None:

View File

@ -1,3 +1,7 @@
from typing import Optional
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@ -6,6 +10,7 @@ class TensorBucket:
self._max_size = size
self._current_size = 0
self._bucket = []
self._write_back_pairs = {}
@property
def max_size(self):
@ -21,7 +26,7 @@ class TensorBucket:
def is_empty(self):
return len(self._bucket) == 0
def add_to_bucket(self, tensor, allow_oversize=False):
def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
@ -30,6 +35,8 @@ class TensorBucket:
self._bucket.append(tensor)
self._current_size += tensor_size
write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
self._write_back_pairs[tensor] = write_back_tensor
def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size
@ -40,12 +47,30 @@ class TensorBucket:
def empty(self):
self._bucket = []
self._size = 0
self._current_size = 0
self._write_back_pairs = {}
def flatten(self):
return _flatten_dense_tensors(self._bucket)
def unflatten(self, flat_tensor):
return _unflatten_dense_tensors(flat_tensor, self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
unflattened_tensor_list = self.unflatten(flat_tensor)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def all_gather(self, group=None):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
write_back_tensor = self._write_back_pairs[tensor]
write_back_tensor.data.copy_(
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
)
self.empty()
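
The write-back map lets the flat all-gather land directly in the working parameters. Here is a self-contained mini version of that bookkeeping, using a single-process gloo group so the sketch runs without a launcher; the real class lives in the ZeRO bookkeeping module.

```python
import os

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Single-process group so the sketch runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

bucket, write_back_pairs = [], {}


def add_to_bucket(shard: torch.Tensor, write_back_tensor: torch.Tensor) -> None:
    bucket.append(shard)
    write_back_pairs[shard] = write_back_tensor  # where the gathered result must land


def all_gather_bucket(group=None) -> None:
    flat = _flatten_dense_tensors(bucket)
    buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
    dist.all_gather(buffers, flat, group=group)
    # regroup per rank, then stitch each tensor's shards back into its write-back target
    unflat_per_rank = [_unflatten_dense_tensors(buf, bucket) for buf in buffers]
    for shards, shard in zip(zip(*unflat_per_rank), bucket):
        dst = write_back_pairs[shard]
        dst.copy_(_flatten_dense_tensors(list(shards))[: dst.numel()].reshape_as(dst))


working_param = torch.zeros(4)
add_to_bucket(torch.arange(4.0), working_param)  # master-shard copy to gather
all_gather_bucket()
print(working_param)  # tensor([0., 1., 2., 3.])
dist.destroy_process_group()
```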

View File

@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
# update working partition updated by the current rank
device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
param_to_gather = splited_param.to(device).to(self._dtype)
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self._bucket_store.moe_extra_dp_pg,
)
try:
moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._bucket_store.zero_world_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self._bucket_store.torch_pg,
)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
try:
tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
tensor_bucket.all_gather(self._bucket_store.torch_pg)
tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
if not moe_tensor_bucket.is_empty():
moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(self._bucket_store.torch_pg)
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""

View File

@ -9,6 +9,7 @@
<a href="https://www.colossalai.org/"> 文档 </a> |
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> 例程 </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
<a href="https://cloud.luchentech.com/">潞晨云 </a> |
<a href="https://hpc-ai.com/blog"> 博客 </a></h3>
[![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers)
@ -127,6 +128,8 @@ Colossal-AI provides you with a series of parallel components. Our goal is to make your
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
<div align="center">
<a href="https://www.bilibili.com/video/BV1Fm421G7bV">
@ -135,6 +138,8 @@ Colossal-AI provides you with a series of parallel components. Our goal is to make your
</div>
### Colossal-LLaMA-2
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
- 7B: half a day of training on a budget of about one thousand RMB yields results comparable to mainstream large models; an open-source, commercially usable Chinese LLaMA-2
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
@ -265,7 +270,9 @@ Colossal-AI provides you with a series of parallel components. Our goal is to make your
</p>
- 70-billion-parameter LLaMA3 model training accelerated by 18%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
### LLaMA2
<p align="center">
@ -378,6 +385,8 @@ Colossal-AI provides you with a series of parallel components. Our goal is to make your
- Large AI model inference speed nearly doubled in some cases, compared with the offline inference performance of vLLM
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference)
[[blog]](https://hpc-ai.com/blog/colossal-inference)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
### Grok-1
<p id="Grok-1" align="center">

View File

@ -1,4 +1,12 @@
# Colossal-AI Examples
<div align="center">
<h3>
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
<a href="https://cloud.luchentech.com/doc/docs/intro"> Playground Document </a>
</h3>
</div>
## Table of Contents

View File

@ -1,9 +1,11 @@
import argparse
import resource
import time
import warnings
from contextlib import nullcontext
import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================
MODEL_CONFIGS = {
"100m": LlamaConfig(
max_position_embeddings=4096,
num_hidden_layers=4,
num_attention_heads=32,
intermediate_size=2048,
hidden_size=1024,
),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
@ -58,6 +68,9 @@ def main():
default="gemini",
help="Choose which plugin to use",
)
parser.add_argument(
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
)
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@ -78,11 +91,13 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
parser.add_argument(
"--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@ -98,6 +113,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
}
if args.custom_ckpt
else {}
@ -174,6 +190,8 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
enable_sequence_parallelism=args.sp > 1,
@ -182,12 +200,16 @@ def main():
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
@ -195,6 +217,7 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@ -210,10 +233,11 @@ def main():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
# ==============================
# Initialize Model and Optimizer
@ -229,8 +253,13 @@ def main():
init_kwargs["empty_init"] = False
with init_ctx:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=True,
**init_kwargs,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if config.model_type == "chatglm":
@ -251,6 +280,7 @@ def main():
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
@ -261,7 +291,7 @@ def main():
with get_profile_context(
args.profile,
1,
args.ignore_steps,
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
@ -269,15 +299,19 @@ def main():
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=False,
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:
@ -288,6 +322,7 @@ def main():
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()
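
For reference, the new benchmark flags map onto the plugin roughly as below; the concrete values are illustrative only, and the call assumes a distributed launch (e.g. `colossalai run` / `torchrun`).

```python
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="interleaved",      # --pp_style interleaved
    num_model_chunks=2,          # --n_chunks 2
    zero_stage=1,                # --zero 1
    microbatch_size=1,           # --mbs 1
    precision="bf16",
    overlap_p2p=True,            # --overlap
    enable_metadata_cache=True,  # turned off by --no_cache
)
```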

View File

@ -40,6 +40,14 @@ def data_gen_for_t5_model():
return data
def data_gen_for_token_classification():
# token classification data gen
# `labels` holds the per-token class id (0 or 1) for token classification, not a token id
data = data_gen_for_encoder_only()
data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
# output transform function
output_transform_fn = lambda x: x
@ -47,6 +55,7 @@ output_transform_fn = lambda x: x
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
loss_fn_for_token_classification = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
@ -79,3 +88,11 @@ model_zoo.register(
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_t5_for_token_classification",
model_fn=lambda: transformers.T5ForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_token_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@ -15,8 +15,7 @@ WORLD_SIZE = 2
def check_p2p_communication():
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager)
p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
rank = dist.get_rank()
tensor = torch.ones(1, device=get_accelerator().get_current_device())
@ -31,41 +30,40 @@ def check_p2p_communication():
for obj in data:
p2p.send_forward(obj)
for i in range(len(data)):
recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
assert recv_obj == data[-(i + 1)]
elif rank == 1:
for obj in data:
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == obj
for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
recv_obj = p2p.recv_forward()
recv_obj, _ = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1:
for obj in data:
p2p.send_backward(obj)
for i in range(len(data)):
recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
assert recv_obj == data[-(i + 1)]
elif rank == 0:
for obj in data:
recv_obj = p2p.recv_backward()
recv_obj, _ = p2p.recv_backward()
assert recv_obj == obj
for i in range(len(data)):
recv_obj = p2p.recv_backward()
p2p.send_forward(data[-(i + 1)])
recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
assert recv_obj == data[i]
if rank == 0:
recv_obj = p2p.send_forward_recv_backward(
recv_obj, _ = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=create_send_metadata(tensor),
)
assert recv_obj == tensor
elif rank == 1:
recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)

View File

@ -52,7 +52,7 @@ def check_stage_manager():
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
group = stage_manager.get_p2p_process_group(prev, cur)
group = stage_manager.get_p2p_process_group()
dist.barrier(group=group)
# check stage groups

View File

@ -41,14 +41,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]
if t5.__class__.__name__ == "T5ForTokenClassification":
row_layer_for_check = ["transformer.shared", "transformer.encoder.block[0].layer[0].SelfAttention.q"]
else:
row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
atol, rtol = 5e-2, 5e-2
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
row_layer_grads = get_grad_tensors_for_check(
t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0
@ -66,7 +69,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ != "T5ForConditionalGeneration":
if org_model.__class__.__name__ not in ["T5ForConditionalGeneration", "T5ForTokenClassification"]:
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@ -157,7 +160,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
@clear_cache_before_run()
def run_t5_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
sub_model_zoo = model_zoo.get_sub_registry(["transformers_t5_for_token_classification"])
for name, (
model_fn,
@ -167,7 +170,10 @@ def run_t5_test(test_config):
_,
) in sub_model_zoo.items():
# skip 4-stage pp test for t5_encoder
if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model":
if test_config["pp_size"] > 2 and name in [
"transformers_t5_encoder_model",
"transformers_t5_for_token_classification",
]:
continue
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

View File

@ -1 +1 @@
0.3.9
0.4.0