diff --git a/README.md b/README.md
index 12d29727b..69506e338 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
Documentation |
Examples |
Forum |
+ GPU Cloud Playground |
Blog
[data:image/s3,"s3://crabby-images/0b579/0b579880349b54d34ed6c79ebcb618fb303a421a" alt="GitHub Repo stars"](https://github.com/hpcaitech/ColossalAI/stargazers)
@@ -132,6 +133,8 @@ distributed training and inference in a few lines.
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
+[[GPU Cloud Playground]](https://cloud.luchentech.com/)
+[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
## Table of Contents
diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md
index a1a76f750..890b1fed3 100644
--- a/applications/ColossalEval/README.md
+++ b/applications/ColossalEval/README.md
@@ -2,6 +2,12 @@
+
+
+
## Table of Contents
diff --git a/applications/README.md b/applications/README.md
index e7c23c7e9..5b8b5e501 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -2,6 +2,15 @@
This directory contains the applications that are powered by Colossal-AI.
+
+
The list of applications include:
- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 474b78aa2..ad131fbe7 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase):
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
- if placement_policy == "auto" and enable_async_reduce:
+ if enable_async_reduce and not pin_memory:
logging.warning(
- f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+ f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
)
pin_memory = True
self.gemini_config = dict(
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fa3c3646a..3bd43f172 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-
+ overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
"""
def __init__(
@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
+ overlap_p2p: bool = True,
) -> None:
super().__init__()
assert (
@@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
- assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
+ assert (
+ self.zero_stage <= 1
+ ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
@@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
+ overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index fea4a23ba..f0cb78c5f 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -134,7 +134,7 @@ class ProcessGroupMesh:
"""
assert mode in ["raise", "wrap", "clip"]
- return np.ravel_multi_index(coord, shape, mode)
+ return int(np.ravel_multi_index(coord, shape, mode))
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
@@ -182,7 +182,7 @@ class ProcessGroupMesh:
axis = [
axis,
]
- assert isinstance(indices_at_axis[0], int)
+ assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
indices_at_axis = [
indices_at_axis,
]
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index a1b54fa1c..f0918c88c 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -122,16 +123,24 @@ class InferenceEngine:
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
-
+ pretrained_path = None
if isinstance(model_or_path, str):
+ import colossalai.interface.pretrained as pretrained_utils
+
try:
- hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+ hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
- # NOTE(lry89757) Currently we load the model using transformers-api,
- # but we will use lazy tensor and checkpoint io to accelerate
- # the model load process in the future.
- model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
+ if arch is "BaichuanForCausalLM":
+ self.logger.warning(
+ "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+ )
+ ctx = LazyInitContext(default_device="cuda")
+ with ctx:
+ model = _supported_models[arch].from_pretrained(
+ model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+ )
+ pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
@@ -189,14 +198,13 @@ class InferenceEngine:
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
- # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
- # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
- # from colossalai.inference.core.plugin import InferCheckpoint_io
+ if pretrained_path:
+ from colossalai.inference.core.plugin import InferCheckpoint_io
- # cpt_io = InferCheckpoint_io()
- # if_has_index_file, model_index_file = has_index_file(model_or_path)
- # assert if_has_index_file, "the model path is invalid"
- # cpt_io.load_model(self.model, model_index_file)
+ cpt_io = InferCheckpoint_io()
+ if_has_index_file, model_index_file = has_index_file(pretrained_path)
+ assert if_has_index_file, "the model path is invalid"
+ cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py
index 439c4b0b5..87222a744 100644
--- a/colossalai/inference/core/rpc_engine.py
+++ b/colossalai/inference/core/rpc_engine.py
@@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):
try:
if isinstance(model_or_path, str):
- self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+ self.model_config = AutoConfig.from_pretrained(
+ model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+ )
elif isinstance(model_or_path, nn.Module):
self.logger.error(
f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py
index 913b8667d..a5199cb74 100644
--- a/colossalai/inference/executor/rpc_worker.py
+++ b/colossalai/inference/executor/rpc_worker.py
@@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
model_policy_map,
)
from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
model_policy (Policy): the policy to replace the model
"""
+ pretrained_path = None
if isinstance(model_or_path, str):
- # is_local = os.path.isdir(model_or_path)
+ import colossalai.interface.pretrained as pretrained_utils
+
try:
- hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+ hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
- # NOTE(lry89757) Currently we load the model using transformers-api,
- # but we will use lazy tensor and checkpoint io to accelerate
- # the model load process in the future.
- model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
- # if is_local:
- # model = _SUPPORTED_MODELS[arch](hf_config)
- # else:
- # # load the real checkpoint
- # model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+ if arch is "BaichuanForCausalLM":
+ self.logger.warning(
+ "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+ )
+ ctx = LazyInitContext(default_device="cuda")
+ with ctx:
+ model = _SUPPORTED_MODELS[arch].from_pretrained(
+ model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+ )
+ pretrained_path = pretrained_utils.get_pretrained_path(model)
except Exception as e:
logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
- # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
- # if isinstance(model_or_path, str) and is_local:
- # from colossalai.inference.core.plugin import InferCheckpoint_io
+ if pretrained_path:
+ from colossalai.inference.core.plugin import InferCheckpoint_io
- # cpt_io = InferCheckpoint_io()
- # if_has_index_file, model_index_file = has_index_file(model_or_path)
- # assert if_has_index_file, "the model path is invalid"
- # cpt_io.load_model(self.model, model_index_file)
+ cpt_io = InferCheckpoint_io()
+ if_has_index_file, model_index_file = has_index_file(pretrained_path)
+ assert if_has_index_file, "the model path is invalid"
+ cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
index 50806a14b..75260f59b 100644
--- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py
+++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py
@@ -1,8 +1,10 @@
from typing import List, Union
+import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
+from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule
@@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
+ LazyInitContext.materialize(module)
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(
module.weight
- ) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+ ) # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
# So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
- return Linear1D_Col.from_native_module(
- module,
- process_group,
- *args,
+ # get the attributes
+ in_features = module.in_features
+ out_features = module.out_features
+ bias = module.bias is not None
+ device = module.weight.device
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+ f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+ )
+
+ lmhead_1d = BaichuanLMHeadLinear1D_Col(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
**kwargs,
)
+
+ return lmhead_1d
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 3bab671c4..dfc53d9f6 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
"""
ParallelModule.__init__(self)
- self.o_proj = attn_oproj
self.config = config
self.num_heads = num_heads
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
self.head_dim = self.hidden_size // self.num_heads
self.process_group = process_group
self.W_pack = W_pack
+ self.o_proj = attn_oproj
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
self.attention_backend = get_attention_backend(model_shard_infer_config)
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 445ec59ce..c7c7473ac 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
self.gate_up_weight = nn.Parameter(
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
)
+ self.gate_up_dict = {
+ "gate_proj.weight": None,
+ "up_proj.weight": None,
+ } # used and delattr in load/shard of gate/up weight
self.down_proj = mlp_dproj
self.process_group = process_group
@@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
):
# NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
- for hook in self._load_state_dict_pre_hooks.values():
- hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ if hasattr(self, "gate_up_dict"):
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
- persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
- local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
- local_state = {k: v for k, v in local_name_params if v is not None}
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+ local_state = {k: v for k, v in local_name_params if v is not None}
- key = "gate_up_weight"
- k1 = "gate_proj.weight"
- k2 = "up_proj.weight"
+ device_mesh = self.helper_layout.device_mesh
+ sharding_spec = self.helper_layout.sharding_spec
+ for weight_name in self.gate_up_dict:
+ prefix_weight_name = prefix + weight_name
+ if prefix_weight_name in state_dict.keys():
+ w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+ self.gate_up_dict[weight_name] = w.T
- gate_w = state_dict[prefix + k1]
- up_w = state_dict[prefix + k2]
+ if None not in self.gate_up_dict.values():
+ # we've got all the weights of gate/up
+ gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
- device_mesh = self.helper_layout.device_mesh
- sharding_spec = self.helper_layout.sharding_spec
- gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
- up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
+ input_param = nn.Parameter(
+ gate_up_w
+ ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
- gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
+ key = "gate_up_weight"
+ param = local_state.get(key, None)
- input_param = nn.Parameter(
- gate_up_w
- ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
- param = local_state[key]
+ try:
+ with torch.no_grad():
+ param.copy_(input_param)
+ except Exception as ex:
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+ )
- try:
- with torch.no_grad():
- param.copy_(input_param)
- except Exception as ex:
- error_msgs.append(
- 'While copying the parameter named "{}", '
- "whose dimensions in the model are {} and "
- "whose dimensions in the checkpoint are {}, "
- "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
- )
+ del self.gate_up_dict
- strict = False # to avoid unexpected_keys
+ strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
self.helper_layout = (
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+ self.qkv_dict = {
+ "q_proj.weight": None,
+ "k_proj.weight": None,
+ "v_proj.weight": None,
+ } # used and delattr in load/shard of qkv weight
else:
+ self.helper_layout = (
+ attn_qproj_w.dist_layout
+ ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
- if self.num_heads == self.num_key_value_heads:
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+ local_state = {k: v for k, v in local_name_params if v is not None}
+
+ device_mesh = self.helper_layout.device_mesh
+ sharding_spec = self.helper_layout.sharding_spec
+
+ if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
- for hook in self._load_state_dict_pre_hooks.values():
- hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
- persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
- local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
- local_state = {k: v for k, v in local_name_params if v is not None}
-
key = "qkv_weight"
- k1 = "q_proj.weight"
- k2 = "k_proj.weight"
- k3 = "v_proj.weight"
- q_w = state_dict[prefix + k1]
- k_w = state_dict[prefix + k2]
- v_w = state_dict[prefix + k3]
- device_mesh = self.helper_layout.device_mesh
- sharding_spec = self.helper_layout.sharding_spec
- q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
- k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
- v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+ # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
+ # Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
+ # Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B)
+ # so here we will stack them when we really collect all the three weights.
+ for weight_name in self.qkv_dict:
+ prefix_weight_name = prefix + weight_name
+ if prefix_weight_name in state_dict.keys():
+ w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+ self.qkv_dict[weight_name] = w.T
- qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+ if None not in self.qkv_dict.values():
+ # we've got all the weights of q, k, v
+ qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
- input_param = nn.Parameter(
- qkv_w
- ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+ input_param = nn.Parameter(
+ qkv_w
+ ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
- param = local_state[key]
+ param = local_state[key]
- try:
- with torch.no_grad():
- param.copy_(input_param)
- except Exception as ex:
- error_msgs.append(
- 'While copying the parameter named "{}", '
- "whose dimensions in the model are {} and "
- "whose dimensions in the checkpoint are {}, "
- "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
- )
+ try:
+ with torch.no_grad():
+ param.copy_(input_param)
+ except Exception as ex:
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+ )
- strict = False # to avoid unexpected_keys
+ del self.qkv_dict
+
+ else:
+
+ def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
+ if prefix + origin_weight_name in state_dict.keys():
+ attn_qproj_w = state_dict[prefix + origin_weight_name]
+ w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
+ input_param = nn.Parameter(w.T)
+ param = local_state[local_weight_name]
+ try:
+ with torch.no_grad():
+ param.copy_(input_param)
+ except Exception as ex:
+ key = local_weight_name
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+ )
+
+ if prefix + "q_proj.weight" in state_dict.keys():
+ _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
+
+ if prefix + "k_proj.weight" in state_dict.keys():
+ _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
+
+ if prefix + "v_proj.weight" in state_dict.keys():
+ _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
+
+ strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 1b55b140c..ed190eb08 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
send_group: Optional[ProcessGroup],
recv_group: Optional[ProcessGroup],
current_device: Any,
+ overlap_p2p: bool = True,
+ send_first: bool = True,
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
- if send_dst is not None and send_tensor_list is not None:
- assert send_group is not None
- _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
- if recv_src is not None and buffer_recv is not None:
- assert recv_group is not None
- _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+ is_send = send_dst is not None and send_tensor_list is not None
+ is_recv = recv_src is not None and buffer_recv is not None
+
+ if send_first:
+ if is_send:
+ assert send_group is not None
+ _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+ if is_recv:
+ assert recv_group is not None
+ _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+ else:
+ if is_recv:
+ assert recv_group is not None
+ _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+ if is_send:
+ assert send_group is not None
+ _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
- for req in reqs:
- req.wait()
-
- # Remove synchronization according to Pytorch's documentation
- # However, the Megatron-LM does synchronization here
- # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
- # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
- # torch.cuda.synchronize()
-
- return buffer_recv
+ if not overlap_p2p:
+ for req in reqs:
+ req.wait()
+ return buffer_recv, []
+ else:
+ return buffer_recv, reqs
+ return None, []
def _send_recv_serialization_object(
@@ -260,10 +270,11 @@ def _send_recv_serialization_object(
recv_group: Optional[ProcessGroup],
current_device: Any,
is_nccl_backend: bool,
+ send_first: bool = True,
) -> Optional[P2PMetadata]:
ops = []
-
send_object_tensor = None
+ send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
@@ -274,43 +285,54 @@ def _send_recv_serialization_object(
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
- _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
-
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
- _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+
+ if send_first:
+ if send_object_size_tensor is not None:
+ _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+ if recv_src is not None:
+ _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+ else:
+ if recv_src is not None:
+ _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+ if send_object_size_tensor is not None:
+ _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
- req.wait()
-
- # See the comment in `_batch_send_recv_tensor`
- # torch.cuda.synchronize()
+ req.wait() # This blocks the compute stream in torch
ops = []
-
- if send_dst is not None and send_object_tensor is not None:
- _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+ is_send = send_dst is not None and send_object_tensor is not None
+ is_recv = recv_src is not None and recv_object_size_tensor is not None
recv_object_tensor = None
- if recv_src is not None and recv_object_size_tensor is not None:
+ if is_recv:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
- _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+
+ if send_first:
+ if is_send:
+ _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+ if is_recv:
+ _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+ else:
+ if is_recv:
+ _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+ if is_send:
+ _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
- # See the comment in `_batch_send_recv_tensor`
- # torch.cuda.synchronize()
-
if recv_object_tensor is not None and recv_object_size_tensor is not None:
recv_object_tensor = recv_object_tensor.type(torch.uint8)
if recv_object_tensor.device != torch.device("cpu"):
@@ -328,11 +350,12 @@ def _communicate(
object: Any,
send_dst: Optional[int],
recv_src: Optional[int],
+ overlap_p2p: bool,
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
- send_prior_fallback: Optional[bool] = None,
+ send_first: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
@@ -341,6 +364,7 @@ def _communicate(
object (Any): object needed to be sent
send_dst (int): rank of the destination
recv_src (int): rank of the source
+ overlap_p2p (bool): whether to overlap p2p communication with computation
send_group (ProcessGroup, optional): process group of sender
recv_group (ProcessGroup, optional): process group of receiver
send_metadata (bool, optional): whether to send metadata
@@ -358,32 +382,10 @@ def _communicate(
# NOTE: if object contains non-tensor objects, we have to send metadata
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
+ else:
+ send_metadata = False
- # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
- # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
- if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
- assert send_prior_fallback is not None, "Priority must be set if fallback happens"
- if send_prior_fallback:
- _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
- return _communicate(
- None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
- )
- else:
- recv_data = _communicate(
- None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
- )
- _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
- return recv_data
-
- # NOTE: only the following 5 cases are valid:
- # 1. send() [needs extra metadata] and no recv()
- # 2. recv() [needs extra metadata] and no send()
- # 3. neither send() nor recv() need extra metadata
- assert not (send_dst is not None and send_metadata) or recv_src is None
- assert not (recv_src is not None and metadata_recv is None) or send_dst is None
- assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
-
current_send_device, is_send_nccl_backend = _check_device(send_group)
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
@@ -402,14 +404,25 @@ def _communicate(
recv_group=recv_group if metadata_recv is None else None,
current_device=current_device,
is_nccl_backend=is_nccl_backend,
+ send_first=send_first if send_first != None else True,
)
- assert metadata_recv is None or _metadata_recv is None
+ assert (
+ metadata_recv is None or _metadata_recv is None
+ ), "You shouldn't receive metadata when using the cached metadata"
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
# Send and receive data
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
- recv_tensor_objs = _batch_send_recv_tensor(
- tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
+ recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
+ tensor_objs,
+ recv_tensor_metadata,
+ send_dst,
+ recv_src,
+ send_group,
+ recv_group,
+ current_device,
+ overlap_p2p=overlap_p2p,
+ send_first=send_first if send_first != None else True,
)
if metadata_recv is not None:
@@ -424,33 +437,9 @@ def _communicate(
for idx in non_tensor_obj_idx:
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
+ return recv_object, wait_handles
- return recv_object
-
-
-def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
- """send anything to dst rank
-
- Args:
- object (Any): object needed to be sent
- dst (int): rank of the destination
-
- Returns:
- None
- """
- _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
-
-
-def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
- """recv anything from src
-
- Args:
- src (int): source rank of data. local rank will receive data from src rank.
-
- Returns:
- Any: Object received from src.
- """
- return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
+ return None, wait_handles
def _p2p_comm(
@@ -532,10 +521,13 @@ def _p2p_comm(
class PipelineP2PCommunication:
- def __init__(self, stage_manager: PipelineStageManager) -> None:
+ def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
self.stage_manager = stage_manager
+ self.overlap_p2p = overlap_p2p
- def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
+ def recv_forward(
+ self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
+ ) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@@ -543,95 +535,186 @@ class PipelineP2PCommunication:
Returns:
Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
- cur_rank = self.stage_manager.get_rank()
- input_tensor = _recv_object(
- prev_rank,
- cur_rank,
- self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
+ input_tensor, wait_handles = _communicate(
+ object=None,
+ recv_src=prev_rank,
+ send_dst=None,
+ recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
+ overlap_p2p=self.overlap_p2p,
)
- return input_tensor
+ return input_tensor, wait_handles
- def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
+ def recv_backward(
+ self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
+ ) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
-
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
- Any: The input gradient tensor or gradient tensor list.
+ Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
- output_tensor_grad = _recv_object(
- next_rank,
- cur_rank,
- self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
+
+ output_tensor_grad, wait_handles = _communicate(
+ object=None,
+ recv_src=next_rank,
+ send_dst=None,
+ recv_group=self.stage_manager.get_p2p_process_group(),
metadata_recv=metadata_recv,
+ overlap_p2p=self.overlap_p2p,
)
- return output_tensor_grad
+ return output_tensor_grad, wait_handles
- def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
+ def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the input tensor to the next stage in pipeline.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
+
+ Returns:
+ List: List of handles for the communication requests, if overlap is enabled.
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
- _send_object(
+ _, handles = _communicate(
output_object,
- cur_rank,
- next_rank,
- self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+ recv_src=None,
+ send_dst=next_rank,
+ send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
+ overlap_p2p=self.overlap_p2p,
)
+ return handles
- def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
+ def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
+
+ Returns:
+ List: List of handles for the communication requests, if overlap is enabled.
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
- cur_rank = self.stage_manager.get_rank()
- _send_object(
+ _, handles = _communicate(
input_object,
- cur_rank,
- prev_rank,
- self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
+ recv_src=None,
+ send_dst=prev_rank,
+ send_group=self.stage_manager.get_p2p_process_group(),
send_metadata=send_metadata,
+ overlap_p2p=self.overlap_p2p,
+ )
+ return handles
+
+ def send_forward_recv_forward(
+ self,
+ output_object: Any,
+ is_send: bool,
+ is_recv: bool,
+ send_first: bool,
+ send_metadata: bool = True,
+ metadata_recv: Optional[P2PMetadata] = None,
+ ) -> Tuple[Any, List]:
+ """Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
+
+ Args:
+ output_object (Any): Object to be sent.
+ is_send (bool): Whether to send the input tensor to the next pipeline stage.
+ is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
+ send_first (bool): Whether to send before receive.
+ send_metadata (bool, optional): Whether to send metadata.
+ metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
+ """
+ next_rank = self.stage_manager.get_next_rank() if is_send else None
+ prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
+ group = self.stage_manager.get_p2p_process_group()
+ return _communicate(
+ output_object,
+ send_dst=next_rank,
+ recv_src=prev_rank,
+ send_group=group if is_send else None,
+ recv_group=group if is_recv else None,
+ send_metadata=send_metadata if is_send else False,
+ metadata_recv=metadata_recv if is_recv else None,
+ send_first=send_first,
+ overlap_p2p=self.overlap_p2p,
+ )
+
+ def send_backward_recv_backward(
+ self,
+ input_object: Any,
+ is_send: bool,
+ is_recv: bool,
+ send_first: bool,
+ send_metadata: bool = True,
+ metadata_recv: Optional[P2PMetadata] = None,
+ ) -> Tuple[Any, List]:
+ """Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
+
+ Args:
+ input_object (Any): Object to be sent.
+ is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
+ is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
+ send_first (bool): Whether to send before receive.
+ send_metadata (bool, optional): Whether to send metadata.
+ metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
+ """
+ prev_rank = self.stage_manager.get_prev_rank() if is_send else None
+ next_rank = self.stage_manager.get_next_rank() if is_recv else None
+
+ group = self.stage_manager.get_p2p_process_group()
+
+ return _communicate(
+ input_object,
+ send_dst=prev_rank,
+ recv_src=next_rank,
+ send_group=group if is_send else None,
+ recv_group=group if is_recv else None,
+ send_metadata=send_metadata if is_send else False,
+ metadata_recv=metadata_recv if is_recv else None,
+ send_first=send_first,
+ overlap_p2p=self.overlap_p2p,
)
def send_forward_recv_backward(
self,
input_object: Any,
- next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
- send_prior_fallback: Optional[bool] = None,
- ) -> Any:
- """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
+ send_first: Optional[bool] = None,
+ ) -> Tuple[Any, List]:
+ """Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
Args:
input_object (Any): Object to be sent.
- next_rank (int, optional): The rank of the sender and recipient of the tensor
- """
- if next_rank is None:
- next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
- group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+ Returns:
+ Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
+ """
+ next_rank = self.stage_manager.get_next_rank()
+ group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
next_rank,
@@ -640,28 +723,28 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
- send_prior_fallback=send_prior_fallback,
+ send_first=send_first,
+ overlap_p2p=False,
)
def send_backward_recv_forward(
self,
input_object: Any,
- prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
- send_prior_fallback: Optional[bool] = None,
- ) -> Any:
+ send_first: Optional[bool] = None,
+ ) -> Tuple[Any, List]:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
input_object (Any): Object to be sent.
- prev_rank (int, optional): The rank of the sender and recipient of the tensor
- """
- if prev_rank is None:
- prev_rank = self.stage_manager.get_prev_rank()
- cur_rank = self.stage_manager.get_rank()
- group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+ Returns:
+ Any: The input tensor or input tensor list.
+ List: List of handles for the communication requests, if overlap is enabled.
+ """
+ prev_rank = self.stage_manager.get_prev_rank()
+ group = self.stage_manager.get_p2p_process_group()
return _communicate(
input_object,
prev_rank,
@@ -670,7 +753,8 @@ class PipelineP2PCommunication:
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
- send_prior_fallback=send_prior_fallback,
+ send_first=send_first,
+ overlap_p2p=False,
)
def p2p_communicate(
@@ -679,7 +763,7 @@ class PipelineP2PCommunication:
recv_pre: bool,
next_rank: Optional[int] = None,
comm_dtype: torch.dtype = torch.float16,
- ) -> None:
+ ) -> Any:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -689,12 +773,11 @@ class PipelineP2PCommunication:
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
- cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(
output_object,
recv_pre,
next_rank,
- self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
+ self.stage_manager.get_p2p_process_group(),
comm_dtype,
)
return recv_tensor
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index a4ace5e1b..a21b45c44 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -1,8 +1,9 @@
from functools import partial
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
+import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
+def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
+ if wait_handles is not None:
+ for req in wait_handles:
+ req.wait()
+
+
class InterleavedSchedule(PipelineSchedule):
def __init__(
self,
@@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
+ overlap_p2p: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
num_microbatch is not None or microbatch_size is not None
), "Either num_microbatch or microbatch_size should be provided"
- self.comm = PipelineP2PCommunication(stage_manager)
+ self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
+ self.overlap_p2p = overlap_p2p
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
@@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
int: The model chunk idx of the input microbatch_id
"""
- assert microbatch_id < self.num_microbatch * self.num_model_chunks
+ assert (
+ microbatch_id < self.num_microbatch * self.num_model_chunks
+ ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
+ # Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
- def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
+ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
@@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input tensor or input tensor list.
+ Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
- input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+ input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
- return input_tensor
+ return input_tensor, wait_handles
+ return None, []
- def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
+ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
@@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Any: The input gradient tensor or gradient tensor list.
+ Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
- output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+ output_tensor_grad, wait_handles = self.comm.recv_backward(
+ next_rank, metadata_recv=self.grad_metadata_recv
+ )
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+ return output_tensor_grad, wait_handles
- return output_tensor_grad
+ return None, []
- def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
+ def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
@@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
+
+ Returns:
+ Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
- self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
+ send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
+ return send_handles
+ return []
- def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
+ def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
@@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
+
+ Returns:
+ Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
- self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
+ send_handles = self.comm.send_backward(
+ input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
+ )
self.send_grad_metadata = not self.enable_metadata_cache
-
- def send_forward_recv_backward(
- self,
- model_chunk_id_send: int,
- model_chunk_id_recv: int,
- output_tensor: Any,
- next_rank: Optional[int] = None,
- send_prior_fallback: Optional[bool] = None,
- ) -> Any:
- with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
- send_data = not self.stage_manager.is_last_stage()
- with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
- recv_data = not self.stage_manager.is_last_stage()
-
- if send_data and recv_data:
- if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
- send_prior_fallback = None # must not fallback
- output_tensor_grad = self.comm.send_forward_recv_backward(
- output_tensor,
- next_rank,
- send_metadata=self.send_tensor_metadata,
- metadata_recv=self.grad_metadata_recv,
- send_prior_fallback=send_prior_fallback,
- )
- self.send_tensor_metadata = not self.enable_metadata_cache
- if self.enable_metadata_cache and self.grad_metadata_recv is None:
- self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
- return output_tensor_grad
-
- # send only or recv only
- self.send_forward(model_chunk_id_send, output_tensor)
- return self.recv_backward(model_chunk_id_recv)
-
- def send_backward_recv_forward(
- self,
- model_chunk_id_send: int,
- model_chunk_id_recv: int,
- input_tensor_grad: Any,
- prev_rank: Optional[int] = None,
- send_prior_fallback: Optional[bool] = None,
- ) -> Any:
- with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
- send_data = not self.stage_manager.is_first_stage()
- with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
- recv_data = not self.stage_manager.is_first_stage()
-
- if send_data and recv_data:
- if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
- send_prior_fallback = None # must not fallback
- input_tensor = self.comm.send_backward_recv_forward(
- input_tensor_grad,
- prev_rank,
- send_metadata=self.send_grad_metadata,
- metadata_recv=self.tensor_metadata_recv,
- send_prior_fallback=send_prior_fallback,
- )
- self.send_grad_metadata = not self.enable_metadata_cache
- if self.enable_metadata_cache and self.tensor_metadata_recv is None:
- self.tensor_metadata_recv = create_send_metadata(input_tensor)
- return input_tensor
-
- # send only or recv only
- self.send_backward(model_chunk_id_send, input_tensor_grad)
- return self.recv_forward(model_chunk_id_recv)
+ return send_handles
+ return []
def send_forward_recv_forward(
- self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
- ):
- if send_prior:
- self.send_forward(model_chunk_id_send, output_tensor)
- input_tensor = self.recv_forward(model_chunk_id_recv)
- else:
- input_tensor = self.recv_forward(model_chunk_id_recv)
- self.send_forward(model_chunk_id_send, output_tensor)
+ self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
+ ) -> Tuple[Any, List]:
+ with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+ is_send = not self.stage_manager.is_last_stage()
+ with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+ is_recv = not self.stage_manager.is_first_stage()
+ input_tensor, wait_handles = self.comm.send_forward_recv_forward(
+ output_tensor,
+ is_send,
+ is_recv,
+ send_metadata=self.send_tensor_metadata,
+ metadata_recv=self.tensor_metadata_recv,
+ send_first=send_first,
+ )
+ # Cache metadata
+ self.send_tensor_metadata = not self.enable_metadata_cache and is_send
+ if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
+ self.tensor_metadata_recv = create_send_metadata(input_tensor)
- return input_tensor
+ return input_tensor, wait_handles
def send_backward_recv_backward(
- self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
- ):
- if send_prior:
- self.send_backward(model_chunk_id_send, input_tensor_grad)
- output_tensor_grad = self.recv_backward(model_chunk_id_recv)
- else:
- output_tensor_grad = self.recv_backward(model_chunk_id_recv)
- self.send_backward(model_chunk_id_send, input_tensor_grad)
-
- return output_tensor_grad
+ self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
+ ) -> Tuple[Any, List]:
+ with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
+ is_send = not self.stage_manager.is_first_stage()
+ with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
+ is_recv = not self.stage_manager.is_last_stage()
+ output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
+ input_tensor_grad,
+ is_send,
+ is_recv,
+ send_metadata=self.send_grad_metadata,
+ metadata_recv=self.grad_metadata_recv,
+ send_first=send_first,
+ )
+ # Cache metadata
+ self.send_grad_metadata = not self.enable_metadata_cache and is_send
+ if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
+ self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+ return output_tensor_grad, wait_handles
def forward_step(
self,
@@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
+ # Load input ids, attention mask and labels
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
- # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
+ # for other stages, input_obj is the output of the previous stage containing hidden_states etc.
+ # Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
@@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
+ fwd_wait_handles = []
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
- input_obj = self.recv_forward(model_chunk_id)
+ input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
for i in range(self.num_microbatch * self.num_model_chunks):
- last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
+ last_batch = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+
+ # Wait until current input is received
+ _wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
- if not last_iteration:
- input_obj = self.send_forward_recv_forward(
+ if not last_batch:
+ input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
- send_prior=self.stage_manager.stage % 2 == 0,
+ send_first=self.stage_manager.stage % 2 == 0,
)
else:
- self.send_forward(model_chunk_id, output_obj)
+ fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
if outputs is not None:
outputs = merge_batch(outputs)
@@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
self.load_batch(data_iter)
num_microbatch = self.num_microbatch * self.num_model_chunks
+ # Forward + until 1st backward
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+ # Steps needed to reach the last chunk
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
@@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())
+ bwd_wait_handles = []
+ # Get the 1st input batch
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
- input_obj = self.recv_forward(model_chunk_id)
+ input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)
+
# Run warmup forward passes.
for i in range(num_warmup_microbatch):
- last_iteration = i == num_warmup_microbatch - 1
+ last_batch = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+
+ # Wait for input
+ _wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
- if last_iteration and num_microbatch_remaining == 0:
- self.send_forward(model_chunk_id, output_obj)
+ if last_batch and num_microbatch_remaining == 0:
+ fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
else:
- input_obj = self.send_forward_recv_forward(
+ input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
- send_prior=self.stage_manager.stage % 2 == 0,
+ send_first=self.stage_manager.stage % 2 == 0,
)
if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
- output_obj_grad = self.recv_backward(model_chunk_id)
+ output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
- last_iteration = i == num_microbatch_remaining - 1
+ fwd_batch_id = i + num_warmup_microbatch
+ last_batch = i == num_microbatch_remaining - 1
+ model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
- model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+ # Wait for input.
+ _wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
@@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
# Pop output_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
- input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
- # NOTE: perform 2x communication for forward and backward
- def send_forward_recv_backward():
- if last_iteration and num_microbatch == num_microbatch_remaining:
- model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
- self.send_forward(model_chunk_id, output_obj)
+ # Helper functions
+ def send_forward_recv_forward():
+ if last_batch:
+ model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
+ wait_handles = self.send_forward(model_chunk_id, output_obj)
+ return None, wait_handles
else:
- output_obj_grad = self.send_forward_recv_backward(
- model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
- model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
+ input_obj, wait_handles = self.send_forward_recv_forward(
+ model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
+ model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
output_tensor=output_obj,
- send_prior_fallback=self.stage_manager.stage % 2 == 0,
+ send_first=self.stage_manager.stage % 2 == 0
+ and i > 0, # Receive from warmup stage first in the first batch
)
- return output_obj_grad
+ return input_obj, wait_handles
- def send_backward_recv_forward():
- if last_iteration:
+ def send_backward_recv_backward():
+ no_cooldown = num_microbatch == num_microbatch_remaining
+ if last_batch and no_cooldown:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
- self.send_backward(model_chunk_id, input_obj_grad)
+ wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
+ return None, wait_handles
else:
- input_obj = self.send_backward_recv_forward(
+ output_obj_grad, wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
- model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
+ model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
- send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
+ send_first=self.stage_manager.stage % 2 == 0,
)
- return input_obj
+ return output_obj_grad, wait_handles
- if self.stage_manager.stage % 2 == 0:
- output_obj_grad = send_forward_recv_backward()
- input_obj = send_backward_recv_forward()
- else:
- input_obj = send_backward_recv_forward()
- output_obj_grad = send_forward_recv_backward()
+ input_obj, fwd_wait_handles = send_forward_recv_forward()
+ # Wait for upstream grad
+ _wait_p2p(bwd_wait_handles)
+ input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+ # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
+ # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
+ # however in practice this works fine, and Megatron does this too
+ # (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
+ # if deadlock, call _wait_p2p(fwd_wait_handles) here
+ output_obj_grad, bwd_wait_handles = send_backward_recv_backward()
if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
- output_obj_grad = self.recv_backward(model_chunk_id)
+ output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)
+
# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
- last_iteration = i == num_microbatch - 1
+ last_batch = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
- # output_obj_grad = self.recv_backward(model_chunk_id)
- input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
- if not last_iteration:
- output_obj_grad = self.send_backward_recv_backward(
+ # Wait for upstream grad
+ _wait_p2p(bwd_wait_handles)
+ # backward local grads
+ input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
+ if not last_batch:
+ output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
- send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
+ send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
+ assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
- self.send_backward(model_chunk_id, input_obj_grad)
+ _ = self.send_backward(model_chunk_id, input_obj_grad)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index bfea8b67d..7f0d0e349 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"
- self.comm = PipelineP2PCommunication(stage_manager)
+ self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
+
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
@@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
- input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
+ input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
@@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
- output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+ output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
- def send_forward_recv_backward(
- self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
- ) -> Any:
+ def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
@@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
- send_prior_fallback = None # must not fallback
- output_tensor_grad = self.comm.send_forward_recv_backward(
+ send_first = None
+ output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
- next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
- send_prior_fallback=send_prior_fallback,
+ send_first=send_first,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
@@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
return output_tensor_grad
- def send_backward_recv_forward(
- self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
- ) -> Any:
+ def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
@@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
- send_prior_fallback = None # must not fallback
- input_tensor = self.comm.send_backward_recv_forward(
+ send_first = None # must not fallback
+ input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
- prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
- send_prior_fallback=send_prior_fallback,
+ send_first=send_first,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
@@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
- output_obj_grad = self.send_forward_recv_backward(
- output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
- )
+ output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
@@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
- input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0
+ input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
)
# Run cooldown backward passes.
diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
index b7cbd67ab..354f110f0 100644
--- a/colossalai/pipeline/stage_manager.py
+++ b/colossalai/pipeline/stage_manager.py
@@ -35,7 +35,7 @@ class PipelineStageManager:
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
- self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
+ self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
@@ -48,30 +48,14 @@ class PipelineStageManager:
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
-
- # init p2p process groups
- stages = list(range(self.num_stages))
- for prev, cur in zip(stages[:-1], stages[1:]):
- group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
- if self.stage in [prev, cur]:
- ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
- self.p2p_groups[tuple(ranks_in_group)] = group
-
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
- if enable_interleave:
- # use circle p2p communication
- # add the process group of the first rank and the last rank
- group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
- if self.stage in [stages[0], stages[-1]]:
- ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
- self.p2p_groups[tuple(ranks_in_group)] = group
-
- # for shardformer, hold stage indices of model
- self.stage_indices: List[Tuple[int, int]]
- # for shardformer, hold model chunk id
- self.model_chunk_id: Optional[int] = None
+ # for shardformer, hold stage indices of model
+ self.stage_indices: List[Tuple[int, int]]
+ # for shardformer, hold model chunk id
+ self.model_chunk_id: Optional[int] = None
+ self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis)
def get_stage_index(
self,
@@ -184,19 +168,12 @@ class PipelineStageManager:
"""
return self.next_rank
- def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
+ def get_p2p_process_group(self) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.
-
- Args:
- first_rank (int): The first rank.
- second_rank (int): The second rank.
-
Returns:
ProcessGroup: P2P process group between the two ranks.
"""
- if first_rank > second_rank:
- first_rank, second_rank = second_rank, first_rank
- return self.p2p_groups[(first_rank, second_rank)]
+ return self.p2p_group
def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
"""Get the process group of the given stages.
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index dc3634238..0f6595a7c 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
"""
+ LazyInitContext.materialize(module)
+
# get the attributes
in_features = module.in_features
out_features = module.out_features
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index b35bb6b94..1b5c03ce4 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -8,8 +8,15 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
+ TokenClassifierOutput,
+)
+from transformers.models.t5.modeling_t5 import (
+ T5EncoderModel,
+ T5ForConditionalGeneration,
+ T5ForTokenClassification,
+ T5Model,
+ T5Stack,
)
-from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -582,6 +589,71 @@ class T5PipelineForwards:
return outputs
+ @staticmethod
+ def t5_for_token_classification_forward(
+ self: T5ForTokenClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+ r"""
+ This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward.
+ Please refer to original code of transformers for more details.
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = T5PipelineForwards.t5_stack_forward(
+ self.transformer.encoder,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage,
+ )
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return outputs
+
def get_t5_flash_attention_forward():
from transformers.models.t5.modeling_t5 import T5Attention
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 008dead6b..99b68aee2 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -68,6 +68,9 @@ _POLICY_LIST = {
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
+ "transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
+ file_name="t5", class_name="T5ForTokenClassificationPolicy"
+ ),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 1298f0af3..0b594678c 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -31,7 +31,13 @@ from ..modeling.t5 import (
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
+__all__ = [
+ "distribute_t5_layers",
+ "T5ModelPolicy",
+ "T5ForConditionalGenerationPolicy",
+ "T5EncoderPolicy",
+ "T5ForTokenClassificationPolicy",
+]
class T5BasePolicy(Policy):
@@ -312,9 +318,13 @@ class T5BasePolicy(Policy):
assert self.pipeline_stage_manager is not None
stage_manager = self.pipeline_stage_manager
- model = self.model
- encoder = self.model.encoder
- decoder = getattr(self.model, "decoder", None)
+ if self.model.__class__.__name__ == "T5ForTokenClassification":
+ model = self.model.transformer
+ else:
+ model = self.model
+
+ encoder = model.encoder
+ decoder = getattr(model, "decoder", None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
@@ -353,7 +363,11 @@ class T5BasePolicy(Policy):
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
- encoder = self.model.encoder
+ if self.model.__class__.__name__ == "T5ForTokenClassification":
+ encoder = self.model.transformer.encoder
+ else:
+ encoder = self.model.encoder
+
decoder = getattr(self.model, "decoder", None)
num_encoder_layers = len(encoder.block)
@@ -542,3 +556,46 @@ class T5EncoderPolicy(T5BasePolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
+
+
+class T5ForTokenClassificationPolicy(T5EncoderPolicy):
+ def module_policy(self):
+ from transformers.models.t5.modeling_t5 import T5ForTokenClassification
+
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ addon_module = {
+ T5ForTokenClassification: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ )
+ ]
+ )
+ }
+ policy.update(addon_module)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(
+ model_cls=T5ForTokenClassification,
+ new_forward=T5PipelineForwards.t5_for_token_classification_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ # no shared params for sequence classification model
+ return []
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 18fbf8fc3..969df9621 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -403,9 +403,9 @@ class Chunk:
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
)
- input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
- self.grad_reduce_work = dist.reduce_scatter(
- self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+ assert self.cuda_global_chunk.is_contiguous()
+ self.grad_reduce_work = dist.reduce_scatter_tensor(
+ self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
)
if self.extra_dp_group is not None:
@@ -520,8 +520,10 @@ class Chunk:
assert self.cuda_shard is not None
alloc_storage(self.cuda_global_chunk)
- gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
- work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
+ assert self.cuda_global_chunk.is_contiguous()
+ work = dist.all_gather_into_tensor(
+ self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+ )
self.cuda_shard = None
self.is_gathered = True
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 3a5f0a5aa..d0e1755f4 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -133,12 +133,12 @@ class ChunkManager:
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
- def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
+ def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
"""Move the shard of the chunk to the target device."""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memory_usage(chunk.memory_usage)
- chunk.shard_move(device, force_copy)
+ chunk.shard_move(device, force_copy, non_blocking=async_move)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index ebdde83b4..80b2c7961 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
p: nn.Parameter,
async_reduce_stream: Optional[torch.cuda.Stream] = None,
):
+ async_reduce_scatter = async_reduce_stream is not None
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
@@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
async_reduce_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(async_reduce_stream):
- reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+ reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
if reduced:
grad_chunk.wait_async_reduce()
if not chunk_manager.reuse_fp16_chunk:
@@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
- chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+ chunk_manager.move_chunk(
+ grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+ )
if not (master_weights) or (enable_gradient_accumulation):
- chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+ chunk_manager.move_chunk(
+ chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+ )
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
index 16ba8a6d6..5b09019b9 100644
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -1,3 +1,7 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -6,6 +10,7 @@ class TensorBucket:
self._max_size = size
self._current_size = 0
self._bucket = []
+ self._write_back_pairs = {}
@property
def max_size(self):
@@ -21,7 +26,7 @@ class TensorBucket:
def is_empty(self):
return len(self._bucket) == 0
- def add_to_bucket(self, tensor, allow_oversize=False):
+ def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
tensor_size = tensor.numel()
if not allow_oversize and self.will_exceed_max_size(tensor_size):
@@ -30,6 +35,8 @@ class TensorBucket:
self._bucket.append(tensor)
self._current_size += tensor_size
+ write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
+ self._write_back_pairs[tensor] = write_back_tensor
def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size
@@ -40,12 +47,30 @@ class TensorBucket:
def empty(self):
self._bucket = []
- self._size = 0
+ self._current_size = 0
+ self._write_back_pairs = {}
def flatten(self):
return _flatten_dense_tensors(self._bucket)
+ def unflatten(self, flat_tensor):
+ return _unflatten_dense_tensors(flat_tensor, self._bucket)
+
def unflatten_and_copy(self, flat_tensor):
- unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
+ unflattened_tensor_list = self.unflatten(flat_tensor)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
+
+ def all_gather(self, group=None):
+ flat = self.flatten()
+ buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
+ dist.all_gather(buffers, flat, group=group)
+ unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+ # transpose the list of list
+ unflat_buffers = list(map(list, zip(*unflat_buffers)))
+ for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+ write_back_tensor = self._write_back_pairs[tensor]
+ write_back_tensor.data.copy_(
+ _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
+ )
+ self.empty()
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5f7f2a4e2..d19e0a002 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
+ tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+ moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+
# update working partition updated by the current rank
device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
+ param_to_gather = splited_param.to(device).to(self._dtype)
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
- all_splited_param = [
- torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
- for _ in range(self._bucket_store.moe_extra_dp_pg_size)
- ]
- dist.all_gather(
- all_splited_param,
- splited_param.to(device).to(self._dtype),
- group=self._bucket_store.moe_extra_dp_pg,
- )
+ try:
+ moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+ except RuntimeError:
+ moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+ moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
else:
- all_splited_param = [
- torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
- for _ in range(self._bucket_store.zero_world_size)
- ]
- dist.all_gather(
- all_splited_param,
- splited_param.to(device).to(self._dtype),
- group=self._bucket_store.torch_pg,
- )
- working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+ try:
+ tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+ except RuntimeError:
+ tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+ tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
+ if not moe_tensor_bucket.is_empty():
+ moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+ if not tensor_bucket.is_empty():
+ tensor_bucket.all_gather(self._bucket_store.torch_pg)
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index dc97b461a..ef121d348 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -9,6 +9,7 @@
文档 |
例程 |
论坛 |
+ 潞晨云 |
博客
[data:image/s3,"s3://crabby-images/0b579/0b579880349b54d34ed6c79ebcb618fb303a421a" alt="GitHub Repo stars"](https://github.com/hpcaitech/ColossalAI/stargazers)
@@ -127,6 +128,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
[[博客]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[模型权重]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
+[[潞晨云]](https://cloud.luchentech.com/)
+[[OpenSora镜像]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
### Colossal-LLaMA-2
+[[潞晨云]](https://cloud.luchentech.com/)
+[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama)
- 7B:千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
@@ -265,7 +270,9 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- 700亿参数LLaMA3训练加速18%
-[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+[[潞晨云]](https://cloud.luchentech.com/)
+[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama)
### LLaMA2
@@ -378,6 +385,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
- AI大模型推理速度部分接近翻倍,与vLLM的离线推理性能相比
[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference)
[[博客]](https://hpc-ai.com/blog/colossal-inference)
+[[潞晨云]](https://cloud.luchentech.com/)
+[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama)
### Grok-1
diff --git a/examples/README.md b/examples/README.md
index b822fb8ff..045632fc3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,4 +1,12 @@
# Colossal-AI Examples
+
## Table of Contents
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index f6c975305..8a35db1f7 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -1,9 +1,11 @@
import argparse
import resource
import time
+import warnings
from contextlib import nullcontext
import torch
+import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
@@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig
+warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================
MODEL_CONFIGS = {
+ "100m": LlamaConfig(
+ max_position_embeddings=4096,
+ num_hidden_layers=4,
+ num_attention_heads=32,
+ intermediate_size=2048,
+ hidden_size=1024,
+ ),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
@@ -58,6 +68,9 @@ def main():
default="gemini",
help="Choose which plugin to use",
)
+ parser.add_argument(
+ "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
+ )
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@@ -78,11 +91,13 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
- parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
- parser.add_argument(
- "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
- )
+
+ parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+ parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
+ parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
+ parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
+ parser.add_argument("--no_cache", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -98,6 +113,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
+ "pp_style": "interleaved",
}
if args.custom_ckpt
else {}
@@ -174,6 +190,8 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
+ pp_style=args.pp_style,
+ num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
enable_sequence_parallelism=args.sp > 1,
@@ -182,12 +200,16 @@ def main():
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
+ overlap_p2p=args.overlap,
+ enable_metadata_cache=not args.no_cache,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
+ pp_style=args.pp_style,
+ num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
@@ -195,6 +217,7 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
+ overlap_p2p=args.overlap,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -210,10 +233,11 @@ def main():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+ torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
- dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
+ dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
# ==============================
# Initialize Model and Optimizer
@@ -229,8 +253,13 @@ def main():
init_kwargs["empty_init"] = False
with init_ctx:
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
-
+ model = AutoModelForCausalLM.from_config(
+ config,
+ trust_remote_code=True,
+ **init_kwargs,
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+ )
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if config.model_type == "chatglm":
@@ -251,6 +280,7 @@ def main():
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
+
torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
@@ -261,7 +291,7 @@ def main():
with get_profile_context(
args.profile,
- 1,
+ args.ignore_steps,
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
@@ -269,15 +299,19 @@ def main():
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
- booster.execute_pipeline(
+ outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
- return_loss=False,
+ return_loss=True,
)
+ loss = outputs["loss"]
+ if dist.get_rank() == dist.get_world_size() - 1:
+ print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
+
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:
@@ -288,6 +322,7 @@ def main():
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
+
performance_evaluator.on_step_end(**batch)
prof.step()
diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index 2ccfb0356..f6ccb297e 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -40,6 +40,14 @@ def data_gen_for_t5_model():
return data
+def data_gen_for_token_classification():
+ # token classification data gen
+ # `labels` is the type not the token id for token classification, 0 or 1
+ data = data_gen_for_encoder_only()
+ data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+ return data
+
+
# output transform function
output_transform_fn = lambda x: x
@@ -47,6 +55,7 @@ output_transform_fn = lambda x: x
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
+loss_fn_for_token_classification = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
@@ -79,3 +88,11 @@ model_zoo.register(
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True),
)
+model_zoo.register(
+ name="transformers_t5_for_token_classification",
+ model_fn=lambda: transformers.T5ForTokenClassification(config),
+ data_gen_fn=data_gen_for_token_classification,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_token_classification,
+ model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py
index 48a8d12e0..30b557f5e 100644
--- a/tests/test_pipeline/test_p2p_communication.py
+++ b/tests/test_pipeline/test_p2p_communication.py
@@ -15,8 +15,7 @@ WORLD_SIZE = 2
def check_p2p_communication():
pg_mesh = ProcessGroupMesh(WORLD_SIZE)
stage_manager = PipelineStageManager(pg_mesh, 0)
- p2p = PipelineP2PCommunication(stage_manager)
-
+ p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
rank = dist.get_rank()
tensor = torch.ones(1, device=get_accelerator().get_current_device())
@@ -31,41 +30,40 @@ def check_p2p_communication():
for obj in data:
p2p.send_forward(obj)
for i in range(len(data)):
- recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
+ recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
assert recv_obj == data[-(i + 1)]
elif rank == 1:
for obj in data:
- recv_obj = p2p.recv_forward()
+ recv_obj, _ = p2p.recv_forward()
assert recv_obj == obj
for i in range(len(data)):
p2p.send_backward(data[-(i + 1)])
- recv_obj = p2p.recv_forward()
+ recv_obj, _ = p2p.recv_forward()
assert recv_obj == data[i]
if rank == 1:
for obj in data:
p2p.send_backward(obj)
for i in range(len(data)):
- recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
+ recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
assert recv_obj == data[-(i + 1)]
elif rank == 0:
for obj in data:
- recv_obj = p2p.recv_backward()
+ recv_obj, _ = p2p.recv_backward()
assert recv_obj == obj
for i in range(len(data)):
- recv_obj = p2p.recv_backward()
- p2p.send_forward(data[-(i + 1)])
+ recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
assert recv_obj == data[i]
if rank == 0:
- recv_obj = p2p.send_forward_recv_backward(
+ recv_obj, _ = p2p.send_forward_recv_backward(
tensor,
send_metadata=False,
metadata_recv=create_send_metadata(tensor),
)
assert recv_obj == tensor
elif rank == 1:
- recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
+ recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
assert recv_obj == tensor
p2p.send_backward(tensor, send_metadata=False)
diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py
index 5146a86c8..a3793013b 100644
--- a/tests/test_pipeline/test_stage_manager.py
+++ b/tests/test_pipeline/test_stage_manager.py
@@ -52,7 +52,7 @@ def check_stage_manager():
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
- group = stage_manager.get_p2p_process_group(prev, cur)
+ group = stage_manager.get_p2p_process_group()
dist.barrier(group=group)
# check stage groups
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 521dc9130..6cdf5bf41 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -41,14 +41,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
- row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]
+ if t5.__class__.__name__ == "T5ForTokenClassification":
+ row_layer_for_check = ["transformer.shared", "transformer.encoder.block[0].layer[0].SelfAttention.q"]
+ else:
+ row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
- atol, rtol = 5e-3, 5e-3
+ atol, rtol = 5e-2, 5e-2
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
row_layer_grads = get_grad_tensors_for_check(
t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0
@@ -66,7 +69,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
- if org_model.__class__.__name__ != "T5ForConditionalGeneration":
+ if org_model.__class__.__name__ not in ["T5ForConditionalGeneration", "T5ForTokenClassification"]:
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -157,7 +160,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
@clear_cache_before_run()
def run_t5_test(test_config):
- sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
+ sub_model_zoo = model_zoo.get_sub_registry(["transformers_t5_for_token_classification"])
for name, (
model_fn,
@@ -167,7 +170,10 @@ def run_t5_test(test_config):
_,
) in sub_model_zoo.items():
# skip 4-stage pp test for t5_encoder
- if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model":
+ if test_config["pp_size"] > 2 and name in [
+ "transformers_t5_encoder_model",
+ "transformers_t5_for_token_classification",
+ ]:
continue
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/version.txt b/version.txt
index 940ac09aa..1d0ba9ea1 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.9
+0.4.0