diff --git a/README.md b/README.md index 12d29727b..69506e338 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Documentation | Examples | Forum | + GPU Cloud Playground | Blog [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) @@ -132,6 +133,8 @@ distributed training and inference in a few lines. [[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) [[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights) [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) +[[GPU Cloud Playground]](https://cloud.luchentech.com/) +[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
@@ -143,6 +146,9 @@ distributed training and inference in a few lines. ### Colossal-LLaMA-2 +[[GPU Cloud Playground]](https://cloud.luchentech.com/) +[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama) + - 7B: One half-day of training using a few hundred dollars yields similar results to mainstream large models, open-source and commercial-free domain-specific LLM solution. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) [[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) @@ -275,6 +281,8 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/) - 70 billion parameter LLaMA3 model training accelerated by 18% [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) +[[GPU Cloud Playground]](https://cloud.luchentech.com/) +[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama) ### LLaMA2

@@ -385,6 +393,8 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt - Large AI models inference speed doubled, compared to the offline inference performance of vLLM in some cases. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference) [[blog]](https://hpc-ai.com/blog/colossal-inference) +[[GPU Cloud Playground]](https://cloud.luchentech.com/) +[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama) ### Grok-1

diff --git a/applications/Colossal-LLaMA/README.md b/applications/Colossal-LLaMA/README.md index 93ba58ac5..5997008e8 100644 --- a/applications/Colossal-LLaMA/README.md +++ b/applications/Colossal-LLaMA/README.md @@ -2,6 +2,12 @@

Colossal-LLaMA

+ +

+ GPU Cloud Playground | + LLaMA3 Image +

+
## Table of Contents diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index a1a76f750..890b1fed3 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -2,6 +2,12 @@

+ +

+ GPU Cloud Playground | + Colossal-Eval Image +

+ ## Table of Contents diff --git a/applications/README.md b/applications/README.md index e7c23c7e9..5b8b5e501 100644 --- a/applications/README.md +++ b/applications/README.md @@ -2,6 +2,15 @@ This directory contains the applications that are powered by Colossal-AI. +
+ +

+ GPU Cloud Playground | + Playground Document +

+ +
+ The list of applications include: - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 474b78aa2..ad131fbe7 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase): assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" - if placement_policy == "auto" and enable_async_reduce: + if enable_async_reduce and not pin_memory: logging.warning( - f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set." + f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." ) pin_memory = True self.gemini_config = dict( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index fa3c3646a..3bd43f172 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. - + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism """ def __init__( @@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, + overlap_p2p: bool = True, ) -> None: super().__init__() assert ( @@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase): assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, @@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_microbatch=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index fea4a23ba..f0cb78c5f 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -134,7 +134,7 @@ class ProcessGroupMesh: """ assert mode in ["raise", "wrap", "clip"] - return np.ravel_multi_index(coord, shape, mode) + return int(np.ravel_multi_index(coord, shape, mode)) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. 
It the process group doesn't exist, it will be created. @@ -182,7 +182,7 @@ class ProcessGroupMesh: axis = [ axis, ] - assert isinstance(indices_at_axis[0], int) + assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}." indices_at_axis = [ indices_at_axis, ] diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index a1b54fa1c..f0918c88c 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.sampler import search_tokens from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.struct import Sequence -from colossalai.inference.utils import get_model_size +from colossalai.inference.utils import get_model_size, has_index_file from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyInitContext from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -122,16 +123,24 @@ class InferenceEngine: model_inference_config: the configuration for modeling initialization when inference. model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. """ - + pretrained_path = None if isinstance(model_or_path, str): + import colossalai.interface.pretrained as pretrained_utils + try: - hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) arch = getattr(hf_config, "architectures")[0] if arch in _supported_models.keys(): - # NOTE(lry89757) Currently we load the model using transformers-api, - # but we will use lazy tensor and checkpoint io to accelerate - # the model load process in the future. - model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) + if arch is "BaichuanForCausalLM": + self.logger.warning( + "Attention ! We use lazy init by default, which could be faster for model loading. 
For baichuan model, the output maybe have a slight difference with transformers" + ) + ctx = LazyInitContext(default_device="cuda") + with ctx: + model = _supported_models[arch].from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) + pretrained_path = pretrained_utils.get_pretrained_path(model) else: # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate raise ValueError(f"Model {arch} is not supported.") @@ -189,14 +198,13 @@ class InferenceEngine: f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" ) - # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor - # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM): - # from colossalai.inference.core.plugin import InferCheckpoint_io + if pretrained_path: + from colossalai.inference.core.plugin import InferCheckpoint_io - # cpt_io = InferCheckpoint_io() - # if_has_index_file, model_index_file = has_index_file(model_or_path) - # assert if_has_index_file, "the model path is invalid" - # cpt_io.load_model(self.model, model_index_file) + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(pretrained_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) free_gpu_memory, _ = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 439c4b0b5..87222a744 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine): try: if isinstance(model_or_path, str): - self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + self.model_config = AutoConfig.from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) elif isinstance(model_or_path, nn.Module): self.logger.error( f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n" diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index 913b8667d..a5199cb74 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import ( model_policy_map, ) from colossalai.inference.sampler import search_tokens -from colossalai.inference.utils import get_model_size +from colossalai.inference.utils import get_model_size, has_index_file from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyInitContext from colossalai.logging import get_dist_logger from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service): model_policy (Policy): the policy to replace the model """ + pretrained_path = None if isinstance(model_or_path, str): - # is_local = os.path.isdir(model_or_path) + import colossalai.interface.pretrained as pretrained_utils + try: - hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) arch = getattr(hf_config, "architectures")[0] - # NOTE(lry89757) Currently we load the model 
using transformers-api, - # but we will use lazy tensor and checkpoint io to accelerate - # the model load process in the future. - model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) - # if is_local: - # model = _SUPPORTED_MODELS[arch](hf_config) - # else: - # # load the real checkpoint - # model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) + if arch == "BaichuanForCausalLM": + self.logger.warning( + "Attention! We use lazy init by default, which could be faster for model loading. For the Baichuan model, the output may differ slightly from transformers." + ) + ctx = LazyInitContext(default_device="cuda") + with ctx: + model = _SUPPORTED_MODELS[arch].from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) + pretrained_path = pretrained_utils.get_pretrained_path(model) except Exception as e: logger.error( f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" @@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service): f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" ) - # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor - # if isinstance(model_or_path, str) and is_local: - # from colossalai.inference.core.plugin import InferCheckpoint_io + if pretrained_path: + from colossalai.inference.core.plugin import InferCheckpoint_io - # cpt_io = InferCheckpoint_io() - # if_has_index_file, model_index_file = has_index_file(model_or_path) - # assert if_has_index_file, "the model path is invalid" - # cpt_io.load_model(self.model, model_index_file) + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(pretrained_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() peak_memory = init_gpu_memory - free_gpu_memory diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py index 50806a14b..75260f59b 100644 --- a/colossalai/inference/modeling/layers/baichuan_tp_linear.py +++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py @@ -1,8 +1,10 @@ from typing import List, Union +import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup +from colossalai.lazy import LazyInitContext from colossalai.shardformer.layer import Linear1D_Col from colossalai.shardformer.layer.parallel_module import ParallelModule @@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col): def from_native_module( module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: + LazyInitContext.materialize(module) module.in_features = module.weight.size(1) module.out_features = module.weight.size(0) module.bias = None module.weight.data = nn.functional.normalize( module.weight - ) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight. + ) # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight. # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue. 
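The checkpoint-loading branch above (`has_index_file` plus `InferCheckpoint_io.load_model`) relies on the index file that ships with sharded checkpoints. As a hedged, self-contained sketch of what that index provides (the `weight_map` key follows the common `*.index.json` convention; `params_per_shard` is an illustrative helper, not part of this patch):

    import json
    from collections import defaultdict

    def params_per_shard(index_path: str) -> dict:
        # "weight_map" maps each parameter name to the shard file that stores it,
        # which is why one layer's q/k/v projections can live in different files.
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        shards = defaultdict(list)
        for param_name, shard_file in weight_map.items():
            shards[shard_file].append(param_name)
        return dict(shards)
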
- return Linear1D_Col.from_native_module( - module, - process_group, - *args, + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + if out_features < tp_size: + return module + + if out_features % tp_size != 0: + raise ValueError( + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + lmhead_1d = BaichuanLMHeadLinear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, **kwargs, ) + + return lmhead_1d + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"]) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index 3bab671c4..dfc53d9f6 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule): attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None. """ ParallelModule.__init__(self) - self.o_proj = attn_oproj self.config = config self.num_heads = num_heads @@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule): self.head_dim = self.hidden_size // self.num_heads self.process_group = process_group self.W_pack = W_pack + self.o_proj = attn_oproj self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel self.attention_backend = get_attention_backend(model_shard_infer_config) self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 445ec59ce..c7c7473ac 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule): self.gate_up_weight = nn.Parameter( torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0) ) + self.gate_up_dict = { + "gate_proj.weight": None, + "up_proj.weight": None, + } # used and delattr in load/shard of gate/up weight self.down_proj = mlp_dproj self.process_group = process_group @@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule): ): # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight) - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + if hasattr(self, "gate_up_dict"): + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - persistent_buffers = {k: v for k, v in self._buffers.items() if k 
not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} - key = "gate_up_weight" - k1 = "gate_proj.weight" - k2 = "up_proj.weight" + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + for weight_name in self.gate_up_dict: + prefix_weight_name = prefix + weight_name + if prefix_weight_name in state_dict.keys(): + w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec) + self.gate_up_dict[weight_name] = w.T - gate_w = state_dict[prefix + k1] - up_w = state_dict[prefix + k2] + if None not in self.gate_up_dict.values(): + # we've got all the weights of gate/up + gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0) - device_mesh = self.helper_layout.device_mesh - sharding_spec = self.helper_layout.sharding_spec - gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec) - up_w = distribute_tensor(up_w, device_mesh, sharding_spec) + input_param = nn.Parameter( + gate_up_w + ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) - gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0) + key = "gate_up_weight" + param = local_state.get(key, None) - input_param = nn.Parameter( - gate_up_w - ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) - param = local_state[key] + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) - try: - with torch.no_grad(): - param.copy_(input_param) - except Exception as ex: - error_msgs.append( - 'While copying the parameter named "{}", ' - "whose dimensions in the model are {} and " - "whose dimensions in the checkpoint are {}, " - "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) - ) + del self.gate_up_dict - strict = False # to avoid unexpected_keys + strict = False # to avoid unexpected_keys super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) @@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): self.helper_layout = ( attn_qproj_w.dist_layout ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) + self.qkv_dict = { + "q_proj.weight": None, + "k_proj.weight": None, + "v_proj.weight": None, + } # used and delattr in load/shard of qkv weight else: + self.helper_layout = ( + attn_qproj_w.dist_layout + ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous()) self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous()) self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous()) @@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule): def 
_load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - if self.num_heads == self.num_key_value_heads: + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + + if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"): # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight) - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - key = "qkv_weight" - k1 = "q_proj.weight" - k2 = "k_proj.weight" - k3 = "v_proj.weight" - q_w = state_dict[prefix + k1] - k_w = state_dict[prefix + k2] - v_w = state_dict[prefix + k3] - device_mesh = self.helper_layout.device_mesh - sharding_spec = self.helper_layout.sharding_spec - q_w = distribute_tensor(q_w, device_mesh, sharding_spec) - k_w = distribute_tensor(k_w, device_mesh, sharding_spec) - v_w = distribute_tensor(v_w, device_mesh, sharding_spec) + # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json + # Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight. + # Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B) + # so here we will stack them when we really collect all the three weights. 
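As a standalone illustration of the accumulate-then-stack idea described in the comment above, here is a hedged sketch in plain PyTorch; the `collect` helper, names, and shapes are made up for illustration, while the real code keys on `q_proj.weight`/`k_proj.weight`/`v_proj.weight` and works on distributed tensors:

    from typing import Optional

    import torch

    qkv_parts = {"q": None, "k": None, "v": None}

    def collect(name: str, weight: torch.Tensor) -> Optional[torch.Tensor]:
        # Remember the (transposed) projection; shards may arrive in any order.
        qkv_parts[name] = weight.T
        if any(w is None for w in qkv_parts.values()):
            return None  # still waiting for the other projections
        return torch.stack(list(qkv_parts.values()), dim=0)  # fused qkv weight

    collect("q", torch.randn(8, 8))
    collect("k", torch.randn(8, 8))
    fused = collect("v", torch.randn(8, 8))  # only now can the stack be built
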
+ for weight_name in self.qkv_dict: + prefix_weight_name = prefix + weight_name + if prefix_weight_name in state_dict.keys(): + w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec) + self.qkv_dict[weight_name] = w.T - qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0) + if None not in self.qkv_dict.values(): + # we've got all the weights of q, k, v + qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0) - input_param = nn.Parameter( - qkv_w - ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + input_param = nn.Parameter( + qkv_w + ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) - param = local_state[key] + param = local_state[key] - try: - with torch.no_grad(): - param.copy_(input_param) - except Exception as ex: - error_msgs.append( - 'While copying the parameter named "{}", ' - "whose dimensions in the model are {} and " - "whose dimensions in the checkpoint are {}, " - "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) - ) + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) - strict = False # to avoid unexpected_keys + del self.qkv_dict + + else: + + def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"): + if prefix + origin_weight_name in state_dict.keys(): + attn_qproj_w = state_dict[prefix + origin_weight_name] + w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec) + input_param = nn.Parameter(w.T) + param = local_state[local_weight_name] + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + key = local_weight_name + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + if prefix + "q_proj.weight" in state_dict.keys(): + _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight") + + if prefix + "k_proj.weight" in state_dict.keys(): + _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight") + + if prefix + "v_proj.weight" in state_dict.keys(): + _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight") + + strict = False # to avoid unexpected_keys super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 1b55b140c..ed190eb08 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -225,31 +225,41 @@ def _batch_send_recv_tensor( send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], current_device: Any, + overlap_p2p: bool = True, + send_first: bool = True, ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]: buffer_recv = None if recv_tensor_metadata is not None: buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device) ops = [] - if send_dst is not None and send_tensor_list is not None: - assert send_group is not None - _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) - if recv_src is not None 
and buffer_recv is not None: - assert recv_group is not None - _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + is_send = send_dst is not None and send_tensor_list is not None + is_recv = recv_src is not None and buffer_recv is not None + + if send_first: + if is_send: + assert send_group is not None + _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) + if is_recv: + assert recv_group is not None + _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + else: + if is_recv: + assert recv_group is not None + _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + if is_send: + assert send_group is not None + _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # Remove synchronization according to Pytorch's documentation - # However, the Megatron-LM does synchronization here - # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112 - # In case there is potential error, uncomment the following `torch.cuda.synchronize()` - # torch.cuda.synchronize() - - return buffer_recv + if not overlap_p2p: + for req in reqs: + req.wait() + return buffer_recv, [] + else: + return buffer_recv, reqs + return None, [] def _send_recv_serialization_object( @@ -260,10 +270,11 @@ def _send_recv_serialization_object( recv_group: Optional[ProcessGroup], current_device: Any, is_nccl_backend: bool, + send_first: bool = True, ) -> Optional[P2PMetadata]: ops = [] - send_object_tensor = None + send_object_size_tensor = None if object is not None and send_dst is not None: if Version(torch.__version__) >= Version("1.13.0"): send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) @@ -274,43 +285,54 @@ def _send_recv_serialization_object( send_object_size_tensor = send_object_size_tensor.to(current_device) send_object_tensor = send_object_tensor.to(current_device) - _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) - recv_object_size_tensor = None if recv_src is not None: recv_object_size_tensor = torch.empty(1, dtype=torch.long) if is_nccl_backend: recv_object_size_tensor = recv_object_size_tensor.to(current_device) - _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + + if send_first: + if send_object_size_tensor is not None: + _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) + if recv_src is not None: + _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + else: + if recv_src is not None: + _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + if send_object_size_tensor is not None: + _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: - req.wait() - - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + req.wait() # This blocks the compute stream in torch ops = [] - - if send_dst is not None and send_object_tensor is not None: - _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) + is_send = send_dst is not None and send_object_tensor is not None + is_recv = recv_src is not None and recv_object_size_tensor is not None recv_object_tensor = None - if recv_src is not None and recv_object_size_tensor is 
not None: + if is_recv: recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8) if is_nccl_backend: recv_object_tensor = recv_object_tensor.to(current_device) - _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + + if send_first: + if is_send: + _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) + if is_recv: + _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + else: + if is_recv: + _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + if is_send: + _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() - if recv_object_tensor is not None and recv_object_size_tensor is not None: recv_object_tensor = recv_object_tensor.type(torch.uint8) if recv_object_tensor.device != torch.device("cpu"): @@ -328,11 +350,12 @@ def _communicate( object: Any, send_dst: Optional[int], recv_src: Optional[int], + overlap_p2p: bool, send_group: Optional[ProcessGroup] = None, recv_group: Optional[ProcessGroup] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, + send_first: Optional[bool] = None, ) -> Any: """ Send and receive object from send_dst and recv_src respectively @@ -341,6 +364,7 @@ def _communicate( object (Any): object needed to be sent send_dst (int): rank of the destination recv_src (int): rank of the source + overlap_p2p (bool): whether to overlap p2p communication with computation send_group (ProcessGroup, optional): process group of sender recv_group (ProcessGroup, optional): process group of receiver send_metadata (bool, optional): whether to send metadata @@ -358,32 +382,10 @@ def _communicate( # NOTE: if object contains non-tensor objects, we have to send metadata metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True) send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0 + else: + send_metadata = False - # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, - # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. - if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): - assert send_prior_fallback is not None, "Priority must be set if fallback happens" - if send_prior_fallback: - _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) - return _communicate( - None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv - ) - else: - recv_data = _communicate( - None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv - ) - _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) - return recv_data - - # NOTE: only the following 5 cases are valid: - # 1. send() [needs extra metadata] and no recv() - # 2. recv() [needs extra metadata] and no send() - # 3. 
neither send() nor recv() need extra metadata - assert not (send_dst is not None and send_metadata) or recv_src is None - assert not (recv_src is not None and metadata_recv is None) or send_dst is None - assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None) assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group) - current_send_device, is_send_nccl_backend = _check_device(send_group) current_recv_device, is_recv_nccl_backend = _check_device(recv_group) @@ -402,14 +404,25 @@ def _communicate( recv_group=recv_group if metadata_recv is None else None, current_device=current_device, is_nccl_backend=is_nccl_backend, + send_first=send_first if send_first != None else True, ) - assert metadata_recv is None or _metadata_recv is None + assert ( + metadata_recv is None or _metadata_recv is None + ), "You shouldn't receive metadata when using the cached metadata" metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv # Send and receive data recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata - recv_tensor_objs = _batch_send_recv_tensor( - tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device + recv_tensor_objs, wait_handles = _batch_send_recv_tensor( + tensor_objs, + recv_tensor_metadata, + send_dst, + recv_src, + send_group, + recv_group, + current_device, + overlap_p2p=overlap_p2p, + send_first=send_first if send_first != None else True, ) if metadata_recv is not None: @@ -424,33 +437,9 @@ def _communicate( for idx in non_tensor_obj_idx: recv_tensor_objs.insert(idx, non_tensor_objs.pop(0)) recv_object = tree_unflatten(recv_tensor_objs, tree_spec) + return recv_object, wait_handles - return recv_object - - -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None: - """send anything to dst rank - - Args: - object (Any): object needed to be sent - dst (int): rank of the destination - - Returns: - None - """ - _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs) - - -def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any: - """recv anything from src - - Args: - src (int): source rank of data. local rank will receive data from src rank. - - Returns: - Any: Object received from src. - """ - return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs) + return None, wait_handles def _p2p_comm( @@ -532,10 +521,13 @@ def _p2p_comm( class PipelineP2PCommunication: - def __init__(self, stage_manager: PipelineStageManager) -> None: + def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None: self.stage_manager = stage_manager + self.overlap_p2p = overlap_p2p - def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: + def recv_forward( + self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None + ) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -543,95 +535,186 @@ class PipelineP2PCommunication: Returns: Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. 
""" if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object( - prev_rank, - cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), + input_tensor, wait_handles = _communicate( + object=None, + recv_src=prev_rank, + send_dst=None, + recv_group=self.stage_manager.get_p2p_process_group(), metadata_recv=metadata_recv, + overlap_p2p=self.overlap_p2p, ) - return input_tensor + return input_tensor, wait_handles - def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: + def recv_backward( + self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None + ) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. - Args: next_rank (int, optional): The rank of the source of the tensor. Returns: - Any: The input gradient tensor or gradient tensor list. + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. """ if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object( - next_rank, - cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank), + + output_tensor_grad, wait_handles = _communicate( + object=None, + recv_src=next_rank, + send_dst=None, + recv_group=self.stage_manager.get_p2p_process_group(), metadata_recv=metadata_recv, + overlap_p2p=self.overlap_p2p, ) - return output_tensor_grad + return output_tensor_grad, wait_handles - def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None: + def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List: """Sends the input tensor to the next stage in pipeline. Args: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + List: List of handles for the communication requests, if overlap is enabled. """ if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object( + _, handles = _communicate( output_object, - cur_rank, - next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + recv_src=None, + send_dst=next_rank, + send_group=self.stage_manager.get_p2p_process_group(), send_metadata=send_metadata, + overlap_p2p=self.overlap_p2p, ) + return handles - def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None: + def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List: """Sends the gradient tensor to the previous stage in pipeline. Args: input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + List: List of handles for the communication requests, if overlap is enabled. 
""" if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - _send_object( + _, handles = _communicate( input_object, - cur_rank, - prev_rank, - self.stage_manager.get_p2p_process_group(cur_rank, prev_rank), + recv_src=None, + send_dst=prev_rank, + send_group=self.stage_manager.get_p2p_process_group(), send_metadata=send_metadata, + overlap_p2p=self.overlap_p2p, + ) + return handles + + def send_forward_recv_forward( + self, + output_object: Any, + is_send: bool, + is_recv: bool, + send_first: bool, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Tuple[Any, List]: + """Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage + + Args: + output_object (Any): Object to be sent. + is_send (bool): Whether to send the input tensor to the next pipeline stage. + is_recv (bool): Whether to copy the output tensor from the next pipeline stage. + send_first (bool): Whether to send before receive. + send_metadata (bool, optional): Whether to send metadata. + metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received. + + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. + """ + next_rank = self.stage_manager.get_next_rank() if is_send else None + prev_rank = self.stage_manager.get_prev_rank() if is_recv else None + group = self.stage_manager.get_p2p_process_group() + return _communicate( + output_object, + send_dst=next_rank, + recv_src=prev_rank, + send_group=group if is_send else None, + recv_group=group if is_recv else None, + send_metadata=send_metadata if is_send else False, + metadata_recv=metadata_recv if is_recv else None, + send_first=send_first, + overlap_p2p=self.overlap_p2p, + ) + + def send_backward_recv_backward( + self, + input_object: Any, + is_send: bool, + is_recv: bool, + send_first: bool, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Tuple[Any, List]: + """Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage + + Args: + input_object (Any): Object to be sent. + is_send (bool): Whether to send the gradient tensor to the previous pipeline stage. + is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage. + send_first (bool): Whether to send before receive. + send_metadata (bool, optional): Whether to send metadata. + metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received. + + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. 
+ """ + prev_rank = self.stage_manager.get_prev_rank() if is_send else None + next_rank = self.stage_manager.get_next_rank() if is_recv else None + + group = self.stage_manager.get_p2p_process_group() + + return _communicate( + input_object, + send_dst=prev_rank, + recv_src=next_rank, + send_group=group if is_send else None, + recv_group=group if is_recv else None, + send_metadata=send_metadata if is_send else False, + metadata_recv=metadata_recv if is_recv else None, + send_first=send_first, + overlap_p2p=self.overlap_p2p, ) def send_forward_recv_backward( self, input_object: Any, - next_rank: Optional[int] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: - """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline + send_first: Optional[bool] = None, + ) -> Tuple[Any, List]: + """Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage Args: input_object (Any): Object to be sent. - next_rank (int, optional): The rank of the sender and recipient of the tensor - """ - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. + """ + next_rank = self.stage_manager.get_next_rank() + group = self.stage_manager.get_p2p_process_group() return _communicate( input_object, next_rank, @@ -640,28 +723,28 @@ class PipelineP2PCommunication: recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, + overlap_p2p=False, ) def send_backward_recv_forward( self, input_object: Any, - prev_rank: Optional[int] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: + send_first: Optional[bool] = None, + ) -> Tuple[Any, List]: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline Args: input_object (Any): Object to be sent. - prev_rank (int, optional): The rank of the sender and recipient of the tensor - """ - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) + Returns: + Any: The input tensor or input tensor list. + List: List of handles for the communication requests, if overlap is enabled. + """ + prev_rank = self.stage_manager.get_prev_rank() + group = self.stage_manager.get_p2p_process_group() return _communicate( input_object, prev_rank, @@ -670,7 +753,8 @@ class PipelineP2PCommunication: recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, + overlap_p2p=False, ) def p2p_communicate( @@ -679,7 +763,7 @@ class PipelineP2PCommunication: recv_pre: bool, next_rank: Optional[int] = None, comm_dtype: torch.dtype = torch.float16, - ) -> None: + ) -> Any: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. 
@@ -689,12 +773,11 @@ class PipelineP2PCommunication: """ if next_rank is None: next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() recv_tensor = _p2p_comm( output_object, recv_pre, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + self.stage_manager.get_p2p_process_group(), comm_dtype, ) return recv_tensor diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a4ace5e1b..a21b45c44 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -1,8 +1,9 @@ from functools import partial -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.cuda +import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map @@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_ from .base import PipelineSchedule +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + class InterleavedSchedule(PipelineSchedule): def __init__( self, @@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule): num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + overlap_p2p: bool = True, ) -> None: super().__init__(stage_manager) assert ( num_microbatch is not None or microbatch_size is not None ), "Either num_microbatch or microbatch_size should be provided" - self.comm = PipelineP2PCommunication(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + self.overlap_p2p = overlap_p2p self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size self.num_model_chunks = num_model_chunks @@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule): Returns: int: The model chunk idx of the input microbatch_id """ - assert microbatch_id < self.num_microbatch * self.num_model_chunks + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not is_forward: + # Reverse order model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For interleaved 1F1B. @@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input tensor or input tensor list. + Any: The wait handles for the communication. 
""" with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) - return input_tensor + return input_tensor, wait_handles + return None, [] - def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. For interleaved 1F1B. @@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + output_tensor_grad, wait_handles = self.comm.recv_backward( + next_rank, metadata_recv=self.grad_metadata_recv + ) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles - return output_tensor_grad + return None, [] - def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None: + def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List: """Sends the input tensor to the next stage in pipeline. For interleaved 1F1B. @@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule): model_chunk_id (int): The current model chunk idx. output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) + send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + return send_handles + return [] - def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None: + def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List: """Sends the gradient tensor to the previous stage in pipeline. For interleaved 1F1B. @@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule): model_chunk_id (int): The current model chunk idx. input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor + + Returns: + Any: The wait handles for the communication. 
""" with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) + send_handles = self.comm.send_backward( + input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata + ) self.send_grad_metadata = not self.enable_metadata_cache - - def send_forward_recv_backward( - self, - model_chunk_id_send: int, - model_chunk_id_recv: int, - output_tensor: Any, - next_rank: Optional[int] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: - with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): - send_data = not self.stage_manager.is_last_stage() - with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): - recv_data = not self.stage_manager.is_last_stage() - - if send_data and recv_data: - if not self.send_forward_recv_backward and self.grad_metadata_recv is not None: - send_prior_fallback = None # must not fallback - output_tensor_grad = self.comm.send_forward_recv_backward( - output_tensor, - next_rank, - send_metadata=self.send_tensor_metadata, - metadata_recv=self.grad_metadata_recv, - send_prior_fallback=send_prior_fallback, - ) - self.send_tensor_metadata = not self.enable_metadata_cache - if self.enable_metadata_cache and self.grad_metadata_recv is None: - self.grad_metadata_recv = create_send_metadata(output_tensor_grad) - return output_tensor_grad - - # send only or recv only - self.send_forward(model_chunk_id_send, output_tensor) - return self.recv_backward(model_chunk_id_recv) - - def send_backward_recv_forward( - self, - model_chunk_id_send: int, - model_chunk_id_recv: int, - input_tensor_grad: Any, - prev_rank: Optional[int] = None, - send_prior_fallback: Optional[bool] = None, - ) -> Any: - with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): - send_data = not self.stage_manager.is_first_stage() - with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): - recv_data = not self.stage_manager.is_first_stage() - - if send_data and recv_data: - if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None: - send_prior_fallback = None # must not fallback - input_tensor = self.comm.send_backward_recv_forward( - input_tensor_grad, - prev_rank, - send_metadata=self.send_grad_metadata, - metadata_recv=self.tensor_metadata_recv, - send_prior_fallback=send_prior_fallback, - ) - self.send_grad_metadata = not self.enable_metadata_cache - if self.enable_metadata_cache and self.tensor_metadata_recv is None: - self.tensor_metadata_recv = create_send_metadata(input_tensor) - return input_tensor - - # send only or recv only - self.send_backward(model_chunk_id_send, input_tensor_grad) - return self.recv_forward(model_chunk_id_recv) + return send_handles + return [] def send_forward_recv_forward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool - ): - if send_prior: - self.send_forward(model_chunk_id_send, output_tensor) - input_tensor = self.recv_forward(model_chunk_id_recv) - else: - input_tensor = self.recv_forward(model_chunk_id_recv) - self.send_forward(model_chunk_id_send, output_tensor) + self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True + ) -> Tuple[Any, List]: + with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): + is_send = not self.stage_manager.is_last_stage() + with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): + is_recv = not 
self.stage_manager.is_first_stage() + input_tensor, wait_handles = self.comm.send_forward_recv_forward( + output_tensor, + is_send, + is_recv, + send_metadata=self.send_tensor_metadata, + metadata_recv=self.tensor_metadata_recv, + send_first=send_first, + ) + # Cache metadata + self.send_tensor_metadata = not self.enable_metadata_cache and is_send + if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_send_metadata(input_tensor) - return input_tensor + return input_tensor, wait_handles def send_backward_recv_backward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool - ): - if send_prior: - self.send_backward(model_chunk_id_send, input_tensor_grad) - output_tensor_grad = self.recv_backward(model_chunk_id_recv) - else: - output_tensor_grad = self.recv_backward(model_chunk_id_recv) - self.send_backward(model_chunk_id_send, input_tensor_grad) - - return output_tensor_grad + self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True + ) -> Tuple[Any, List]: + with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): + is_send = not self.stage_manager.is_first_stage() + with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): + is_recv = not self.stage_manager.is_last_stage() + output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward( + input_tensor_grad, + is_send, + is_recv, + send_metadata=self.send_grad_metadata, + metadata_recv=self.grad_metadata_recv, + send_first=send_first, + ) + # Cache metadata + self.send_grad_metadata = not self.enable_metadata_cache and is_send + if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + return output_tensor_grad, wait_handles def forward_step( self, @@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule): Returns: Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ + # Load input ids, attention mask and labels micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # for the first stage, input_obj is None - # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + # for other stages, input_obj is the output of the previous stage containing hidden_states etc. 
+ # Only attention_mask from micro_batch is used with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): @@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule): if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) + fwd_wait_handles = [] model_chunk_id = self.get_model_chunk_id(0, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id) for i in range(self.num_microbatch * self.num_model_chunks): - last_iteration = i == self.num_microbatch * self.num_model_chunks - 1 + last_batch = i == self.num_microbatch * self.num_model_chunks - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + + # Wait until current input is received + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if not last_iteration: - input_obj = self.send_forward_recv_forward( + if not last_batch: + input_obj, fwd_wait_handles = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), output_tensor=output_obj, - send_prior=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0, ) else: - self.send_forward(model_chunk_id, output_obj) + fwd_wait_handles = self.send_forward(model_chunk_id, output_obj) if outputs is not None: outputs = merge_batch(outputs) @@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule): self.load_batch(data_iter) num_microbatch = self.num_microbatch * self.num_model_chunks + # Forward + until 1st backward num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + # Steps needed to reach the last chunk num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) num_microbatch_remaining = num_microbatch - num_warmup_microbatch @@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule): if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) + bwd_wait_handles = [] + # Get the 1st input batch model_chunk_id = self.get_model_chunk_id(0, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id) + # Run warmup forward passes. 
for i in range(num_warmup_microbatch): - last_iteration = i == num_warmup_microbatch - 1 + last_batch = i == num_warmup_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + + # Wait for input + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) - if last_iteration and num_microbatch_remaining == 0: - self.send_forward(model_chunk_id, output_obj) + if last_batch and num_microbatch_remaining == 0: + fwd_wait_handles = self.send_forward(model_chunk_id, output_obj) else: - input_obj = self.send_forward_recv_forward( + input_obj, fwd_wait_handles = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), output_tensor=output_obj, - send_prior=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0, ) if num_microbatch_remaining > 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) - output_obj_grad = self.recv_backward(model_chunk_id) + output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id) # Run 1F1B in steady state. for i in range(num_microbatch_remaining): - last_iteration = i == num_microbatch_remaining - 1 + fwd_batch_id = i + num_warmup_microbatch + last_batch = i == num_microbatch_remaining - 1 + model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True) - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + # Wait for input. + _wait_p2p(fwd_wait_handles) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) @@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule): # Pop output_obj and output_obj from the start of the list for the backward pass. 
_input_obj = input_objs[model_chunk_id].pop(0) _output_obj = output_objs[model_chunk_id].pop(0) - input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) - # NOTE: perform 2x communication for forward and backward - def send_forward_recv_backward(): - if last_iteration and num_microbatch == num_microbatch_remaining: - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) - self.send_forward(model_chunk_id, output_obj) + # Helper functions + def send_forward_recv_forward(): + if last_batch: + model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True) + wait_handles = self.send_forward(model_chunk_id, output_obj) + return None, wait_handles else: - output_obj_grad = self.send_forward_recv_backward( - model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True), - model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), + input_obj, wait_handles = self.send_forward_recv_forward( + model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True), output_tensor=output_obj, - send_prior_fallback=self.stage_manager.stage % 2 == 0, + send_first=self.stage_manager.stage % 2 == 0 + and i > 0, # Receive from warmup stage first in the first batch ) - return output_obj_grad + return input_obj, wait_handles - def send_backward_recv_forward(): - if last_iteration: + def send_backward_recv_backward(): + no_cooldown = num_microbatch == num_microbatch_remaining + if last_batch and no_cooldown: model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) + wait_handles = self.send_backward(model_chunk_id, input_obj_grad) + return None, wait_handles else: - input_obj = self.send_backward_recv_forward( + output_obj_grad, wait_handles = self.send_backward_recv_backward( model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), - model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), input_tensor_grad=input_obj_grad, - send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0, + send_first=self.stage_manager.stage % 2 == 0, ) - return input_obj + return output_obj_grad, wait_handles - if self.stage_manager.stage % 2 == 0: - output_obj_grad = send_forward_recv_backward() - input_obj = send_backward_recv_forward() - else: - input_obj = send_backward_recv_forward() - output_obj_grad = send_forward_recv_backward() + input_obj, fwd_wait_handles = send_forward_recv_forward() + # Wait for upstream grad + _wait_p2p(bwd_wait_handles) + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv) + # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html) + # however in practice this works fine, and Megatron does this too + # (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774) + # if deadlock, call _wait_p2p(fwd_wait_handles) here + output_obj_grad, bwd_wait_handles = send_backward_recv_backward() if num_microbatch_remaining == 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) - output_obj_grad = self.recv_backward(model_chunk_id) + output_obj_grad, 
bwd_wait_handles = self.recv_backward(model_chunk_id) + # Run cooldown backward passes. for i in range(num_microbatch_remaining, num_microbatch): - last_iteration = i == num_microbatch - 1 + last_batch = i == num_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=False) _input_obj = input_objs[model_chunk_id].pop(0) _output_obj = output_objs[model_chunk_id].pop(0) - # output_obj_grad = self.recv_backward(model_chunk_id) - input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) - if not last_iteration: - output_obj_grad = self.send_backward_recv_backward( + # Wait for upstream grad + _wait_p2p(bwd_wait_handles) + # backward local grads + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + if not last_batch: + output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward( model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), input_tensor_grad=input_obj_grad, - send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining, + send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining, ) + assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0 else: model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) + _ = self.send_backward(model_chunk_id, input_obj_grad) assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index bfea8b67d..7f0d0e349 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): num_microbatches is not None or microbatch_size is not None ), "Either num_microbatches or microbatch_size should be provided" - self.comm = PipelineP2PCommunication(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False) + self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size self.batch: Optional[Any] = None @@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Any: The input tensor or input tensor list. """ if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) @@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Any: The input gradient tensor or gradient tensor list. 
""" if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) @@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) self.send_grad_metadata = not self.enable_metadata_cache - def send_forward_recv_backward( - self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None - ) -> Any: + def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. @@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_last_stage(): if not self.send_tensor_metadata and self.grad_metadata_recv is not None: - send_prior_fallback = None # must not fallback - output_tensor_grad = self.comm.send_forward_recv_backward( + send_first = None + output_tensor_grad, _ = self.comm.send_forward_recv_backward( output_tensor, - next_rank, send_metadata=self.send_tensor_metadata, metadata_recv=self.grad_metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, ) self.send_tensor_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.grad_metadata_recv is None: @@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): return output_tensor_grad - def send_backward_recv_forward( - self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None - ) -> Any: + def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. @@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_first_stage(): if not self.send_grad_metadata and self.tensor_metadata_recv is not None: - send_prior_fallback = None # must not fallback - input_tensor = self.comm.send_backward_recv_forward( + send_first = None # must not fallback + input_tensor, _ = self.comm.send_backward_recv_forward( input_tensor_grad, - prev_rank, send_metadata=self.send_grad_metadata, metadata_recv=self.tensor_metadata_recv, - send_prior_fallback=send_prior_fallback, + send_first=send_first, ) self.send_grad_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.tensor_metadata_recv is None: @@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - output_obj_grad = self.send_forward_recv_backward( - output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0 - ) + output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0) # Add input_obj and output_obj to end of list. 
input_objs.append(input_obj) output_objs.append(output_obj) @@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_backward(input_obj_grad) else: input_obj = self.send_backward_recv_forward( - input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0 + input_obj_grad, send_first=self.stage_manager.stage % 2 == 0 ) # Run cooldown backward passes. diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b7cbd67ab..354f110f0 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -35,7 +35,7 @@ class PipelineStageManager: self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None - self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {} if num_layers_per_stage is not None: assert len(num_layers_per_stage) == self.num_stages self.num_layers_per_stage = num_layers_per_stage @@ -48,30 +48,14 @@ class PipelineStageManager: # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") - - # init p2p process groups - stages = list(range(self.num_stages)) - for prev, cur in zip(stages[:-1], stages[1:]): - group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) - if self.stage in [prev, cur]: - ranks_in_group = self.pg_mesh.get_ranks_in_group(group) - self.p2p_groups[tuple(ranks_in_group)] = group - self.is_interleave = enable_interleave # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks - if enable_interleave: - # use circle p2p communication - # add the process group of the first rank and the last rank - group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) - if self.stage in [stages[0], stages[-1]]: - ranks_in_group = self.pg_mesh.get_ranks_in_group(group) - self.p2p_groups[tuple(ranks_in_group)] = group - - # for shardformer, hold stage indices of model - self.stage_indices: List[Tuple[int, int]] - # for shardformer, hold model chunk id - self.model_chunk_id: Optional[int] = None + # for shardformer, hold stage indices of model + self.stage_indices: List[Tuple[int, int]] + # for shardformer, hold model chunk id + self.model_chunk_id: Optional[int] = None + self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis) def get_stage_index( self, @@ -184,19 +168,12 @@ class PipelineStageManager: """ return self.next_rank - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: + def get_p2p_process_group(self) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. - - Args: - first_rank (int): The first rank. - second_rank (int): The second rank. - Returns: ProcessGroup: P2P process group between the two ranks. """ - if first_rank > second_rank: - first_rank, second_rank = second_rank, first_rank - return self.p2p_groups[(first_rank, second_rank)] + return self.p2p_group def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: """Get the process group of the given stages. 
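The schedule and p2p changes above switch the communication helpers from blocking calls to ones that return the received object together with a list of wait handles, and `_wait_p2p(...)` is only called right before the data is actually consumed. A minimal sketch of that overlap pattern, using plain `torch.distributed` batched p2p rather than the actual `PipelineP2PCommunication` class (the function name and peer arguments here are illustrative, and `_wait_p2p` is an assumed minimal equivalent of the helper used in the diff):

```python
import torch
import torch.distributed as dist


def isend_irecv(send_buf: torch.Tensor, recv_buf: torch.Tensor,
                send_peer: int, recv_peer: int, send_first: bool = True):
    """Post a send and a receive together and return the async work handles."""
    ops = [
        dist.P2POp(dist.isend, send_buf, send_peer),
        dist.P2POp(dist.irecv, recv_buf, recv_peer),
    ]
    if not send_first:
        ops.reverse()
    return dist.batch_isend_irecv(ops)


def _wait_p2p(handles) -> None:
    """Block only when the received data is actually needed."""
    for handle in handles:
        handle.wait()


# Sketch of the schedule's usage pattern:
#   handles = isend_irecv(output_obj, next_input, next_rank, prev_rank)
#   ...compute on data that has already arrived...
#   _wait_p2p(handles)   # wait right before next_input is consumed
```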
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index dc3634238..0f6595a7c 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule): process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight. """ + LazyInitContext.materialize(module) + # get the attributes in_features = module.in_features out_features = module.out_features diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index b35bb6b94..1b5c03ce4 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -8,8 +8,15 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + TokenClassifierOutput, +) +from transformers.models.t5.modeling_t5 import ( + T5EncoderModel, + T5ForConditionalGeneration, + T5ForTokenClassification, + T5Model, + T5Stack, ) -from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -582,6 +589,71 @@ class T5PipelineForwards: return outputs + @staticmethod + def t5_for_token_classification_forward( + self: T5ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + position_bias: Optional[torch.Tensor] = None, + encoder_decoder_position_bias: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + backward_tensor_keys: Optional[List[str]] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward. + Please refer to original code of transformers for more details. 
+ ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = T5PipelineForwards.t5_stack_forward( + self.transformer.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + def get_t5_flash_attention_forward(): from transformers.models.t5.modeling_t5 import T5Attention diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 008dead6b..99b68aee2 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -68,6 +68,9 @@ _POLICY_LIST = { file_name="t5", class_name="T5ForConditionalGenerationPolicy" ), "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), + "transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation( + file_name="t5", class_name="T5ForTokenClassificationPolicy" + ), # GPT2 "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 1298f0af3..0b594678c 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -31,7 +31,13 @@ from ..modeling.t5 import ( ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] +__all__ = [ + "distribute_t5_layers", + "T5ModelPolicy", + "T5ForConditionalGenerationPolicy", + "T5EncoderPolicy", + "T5ForTokenClassificationPolicy", +] class T5BasePolicy(Policy): @@ -312,9 +318,13 @@ class T5BasePolicy(Policy): assert self.pipeline_stage_manager is not None stage_manager = self.pipeline_stage_manager - model = self.model - encoder = self.model.encoder - decoder = getattr(self.model, "decoder", None) + if self.model.__class__.__name__ == "T5ForTokenClassification": + model = self.model.transformer + else: + model = self.model + + encoder = model.encoder + decoder = getattr(model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -353,7 +363,11 @@ class T5BasePolicy(Policy): raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - encoder = self.model.encoder + if 
self.model.__class__.__name__ == "T5ForTokenClassification": + encoder = self.model.transformer.encoder + else: + encoder = self.model.encoder + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) @@ -542,3 +556,46 @@ class T5EncoderPolicy(T5BasePolicy): def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] + + +class T5ForTokenClassificationPolicy(T5EncoderPolicy): + def module_policy(self): + from transformers.models.t5.modeling_t5 import T5ForTokenClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + T5ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + } + policy.update(addon_module) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=T5ForTokenClassification, + new_forward=T5PipelineForwards.t5_for_token_classification_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # no shared params for sequence classification model + return [] diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 18fbf8fc3..969df9621 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -403,9 +403,9 @@ class Chunk: self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() ) - input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - self.grad_reduce_work = dist.reduce_scatter( - self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op + assert self.cuda_global_chunk.is_contiguous() + self.grad_reduce_work = dist.reduce_scatter_tensor( + self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op ) if self.extra_dp_group is not None: @@ -520,8 +520,10 @@ class Chunk: assert self.cuda_shard is not None alloc_storage(self.cuda_global_chunk) - gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0)) - work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op) + assert self.cuda_global_chunk.is_contiguous() + work = dist.all_gather_into_tensor( + self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op + ) self.cuda_shard = None self.is_gathered = True diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 3a5f0a5aa..d0e1755f4 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -133,12 +133,12 @@ class ChunkManager: self.__sub_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) - def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None: """Move the shard of the chunk to the target device.""" if not chunk.can_move or chunk.device_type == device.type: return self.__sub_memory_usage(chunk.memory_usage) - chunk.shard_move(device, force_copy) + 
chunk.shard_move(device, force_copy, non_blocking=async_move) self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index ebdde83b4..80b2c7961 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper): p: nn.Parameter, async_reduce_stream: Optional[torch.cuda.Stream] = None, ): + async_reduce_scatter = async_reduce_stream is not None setattr(p, "_gemini_reduced", True) empty_grad = torch.empty_like(grad) free_storage(empty_grad) @@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper): async_reduce_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(async_reduce_stream): - reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None)) + reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter) if reduced: grad_chunk.wait_async_reduce() if not chunk_manager.reuse_fp16_chunk: @@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper): # record l2 norm for gradient clipping. flag is bound to fp16 chunk if chunk.l2_norm_flag: grad_chunk.set_l2_norm() - chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) + chunk_manager.move_chunk( + grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter + ) if not (master_weights) or (enable_gradient_accumulation): - chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) + chunk_manager.move_chunk( + chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter + ) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index 16ba8a6d6..5b09019b9 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -1,3 +1,7 @@ +from typing import Optional + +import torch +import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -6,6 +10,7 @@ class TensorBucket: self._max_size = size self._current_size = 0 self._bucket = [] + self._write_back_pairs = {} @property def max_size(self): @@ -21,7 +26,7 @@ class TensorBucket: def is_empty(self): return len(self._bucket) == 0 - def add_to_bucket(self, tensor, allow_oversize=False): + def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None): tensor_size = tensor.numel() if not allow_oversize and self.will_exceed_max_size(tensor_size): @@ -30,6 +35,8 @@ class TensorBucket: self._bucket.append(tensor) self._current_size += tensor_size + write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor + self._write_back_pairs[tensor] = write_back_tensor def will_exceed_max_size(self, tensor_size): expected_size = self._current_size + tensor_size @@ -40,12 +47,30 @@ class TensorBucket: def empty(self): self._bucket = [] - self._size = 0 + self._current_size = 0 + self._write_back_pairs = {} def flatten(self): return _flatten_dense_tensors(self._bucket) + def unflatten(self, flat_tensor): + return _unflatten_dense_tensors(flat_tensor, self._bucket) + def unflatten_and_copy(self, flat_tensor): - unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket) + unflattened_tensor_list = self.unflatten(flat_tensor) for old, new in 
zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
+
+    def all_gather(self, group=None):
+        flat = self.flatten()
+        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(buffers, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            write_back_tensor.data.copy_(
+                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
+            )
+        self.empty()
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5f7f2a4e2..d19e0a002 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
-from .bookkeeping import BucketStore, GradientStore, ParameterStore
+from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+        moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
+                param_to_gather = splited_param.to(device).to(self._dtype)
                 if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.moe_extra_dp_pg,
-                    )
+                    try:
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
+                        moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._bucket_store.zero_world_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param,
-                        splited_param.to(device).to(self._dtype),
-                        group=self._bucket_store.torch_pg,
-                    )
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+                    try:
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        tensor_bucket.all_gather(self._bucket_store.torch_pg)
+                        tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
+        if not moe_tensor_bucket.is_empty():
+            
moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg) + if not tensor_bucket.is_empty(): + tensor_bucket.all_gather(self._bucket_store.torch_pg) def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: r""" diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index dc97b461a..ef121d348 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -9,6 +9,7 @@ 文档 | 例程 | 论坛 | + 潞晨云 | 博客 [![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers) @@ -127,6 +128,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [[博客]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use) [[模型权重]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights) [[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) +[[潞晨云]](https://cloud.luchentech.com/) +[[OpenSora镜像]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
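The `TensorBucket.all_gather` path above replaces many per-parameter `all_gather` calls with one flat gather per bucket, then copies the reassembled full tensors back into the working parameters. A hedged, stand-alone sketch of that write-back step; `bucketed_all_gather` is an illustrative name, not the optimizer's API:

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucketed_all_gather(shards, working_params, group=None):
    """Gather a bucket of ZeRO shards with a single flat all_gather and copy the
    reassembled full tensors back into the working parameters in place."""
    world_size = dist.get_world_size(group)
    flat = _flatten_dense_tensors(shards)
    gathered = [torch.empty_like(flat) for _ in range(world_size)]
    dist.all_gather(gathered, flat, group=group)
    # per-rank flat buffers -> per-tensor lists of shards (one shard per rank)
    per_rank = [_unflatten_dense_tensors(buf, shards) for buf in gathered]
    per_tensor = list(map(list, zip(*per_rank)))
    for shard_list, working in zip(per_tensor, working_params):
        full = _flatten_dense_tensors(shard_list)
        working.data.copy_(full[: working.numel()].reshape_as(working))
```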
@@ -135,6 +138,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
### Colossal-LLaMA-2 +[[潞晨云]](https://cloud.luchentech.com/) +[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama) - 7B:千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2 [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) @@ -265,7 +270,9 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

- 700亿参数LLaMA3训练加速18% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama) +[[潞晨云]](https://cloud.luchentech.com/) +[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama) ### LLaMA2

@@ -378,6 +385,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - AI大模型推理速度部分接近翻倍,与vLLM的离线推理性能相比 [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference) [[博客]](https://hpc-ai.com/blog/colossal-inference) +[[潞晨云]](https://cloud.luchentech.com/) +[[LLaMA3 镜像]](https://cloud.luchentech.com/doc/docs/image/llama) ### Grok-1

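The Gemini chunk changes earlier in this diff drop the list-of-views collectives in favor of their flat-tensor counterparts, which avoid building Python lists of shards and write directly into the contiguous chunk buffer. A minimal sketch of the two calls as used there; the wrapper names and arguments are illustrative, and both assume the global chunk is contiguous and exactly `world_size` shards long:

```python
import torch
import torch.distributed as dist


def gather_chunk(cuda_global_chunk: torch.Tensor, cuda_shard: torch.Tensor, pg) -> None:
    # Counterpart of all_gather over torch.chunk(...) views: every rank's shard
    # lands directly in its slice of the contiguous full chunk.
    assert cuda_global_chunk.is_contiguous()
    dist.all_gather_into_tensor(cuda_global_chunk, cuda_shard, group=pg)


def reduce_chunk(cuda_shard: torch.Tensor, cuda_global_chunk: torch.Tensor, pg, async_op: bool = False):
    # Counterpart of reduce_scatter over chunked views; returns the async work
    # handle when async_op=True so the reduction can overlap with compute.
    assert cuda_global_chunk.is_contiguous()
    return dist.reduce_scatter_tensor(cuda_shard, cuda_global_chunk, group=pg, async_op=async_op)
```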
diff --git a/examples/README.md b/examples/README.md index b822fb8ff..045632fc3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,4 +1,12 @@ # Colossal-AI Examples
+
+   GPU Cloud Playground |
+   Playground Document
+
## Table of Contents diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index f6c975305..8a35db1f7 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -1,9 +1,11 @@ import argparse import resource import time +import warnings from contextlib import nullcontext import torch +import torch.distributed as dist from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator, get_profile_context @@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import PipelineGradientCheckpointConfig +warnings.filterwarnings("ignore") # ============================== # Constants # ============================== MODEL_CONFIGS = { + "100m": LlamaConfig( + max_position_embeddings=4096, + num_hidden_layers=4, + num_attention_heads=32, + intermediate_size=2048, + hidden_size=1024, + ), "7b": LlamaConfig(max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, @@ -58,6 +68,9 @@ def main(): default="gemini", help="Choose which plugin to use", ) + parser.add_argument( + "--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel." + ) parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") @@ -78,11 +91,13 @@ def main(): parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) - parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) - parser.add_argument( - "--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False - ) + + parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) + parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) + parser.add_argument("--profile", action="store_true", help="Profile the code", default=False) + parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") + parser.add_argument("--no_cache", action="store_true") args = parser.parse_args() colossalai.launch_from_torch() @@ -98,6 +113,7 @@ def main(): num_ckpt_layers_per_stage=[19, 19, 19, 13], ), "num_layers_per_stage": [19, 20, 20, 21], + "pp_style": "interleaved", } if args.custom_ckpt else {} @@ -174,6 +190,8 @@ def main(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, zero_stage=args.zero, sp_size=args.sp, enable_sequence_parallelism=args.sp > 1, @@ -182,12 +200,16 @@ def main(): microbatch_size=args.mbs, precision="bf16", dp_outside=False, + overlap_p2p=args.overlap, + enable_metadata_cache=not args.no_cache, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style=args.pp_style, + num_model_chunks=args.n_chunks, zero_stage=args.zero, 
cpu_offload=True, enable_fused_normalization=torch.cuda.is_available(), @@ -195,6 +217,7 @@ def main(): microbatch_size=args.mbs, initial_scale=2**8, precision="bf16", + overlap_p2p=args.overlap, ) else: raise ValueError(f"Unknown plugin {args.plugin}") @@ -210,10 +233,11 @@ def main(): config = MODEL_CONFIGS[args.config] else: config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) + torch.cuda.manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) - dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42) # ============================== # Initialize Model and Optimizer @@ -229,8 +253,13 @@ def main(): init_kwargs["empty_init"] = False with init_ctx: - model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs) - + model = AutoModelForCausalLM.from_config( + config, + trust_remote_code=True, + **init_kwargs, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + ) if args.grad_checkpoint: model.gradient_checkpointing_enable() if config.model_type == "chatglm": @@ -251,6 +280,7 @@ def main(): optimizer = HybridAdam(model.parameters()) torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) coordinator.print_on_master( f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" @@ -261,7 +291,7 @@ def main(): with get_profile_context( args.profile, - 1, + args.ignore_steps, len(dataloader) - 1, save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", ) as prof: @@ -269,15 +299,19 @@ def main(): data_iter = iter(dataloader) for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): performance_evaluator.on_step_start(step) - booster.execute_pipeline( + outputs = booster.execute_pipeline( data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, - return_loss=False, + return_loss=True, ) + loss = outputs["loss"] + if dist.get_rank() == dist.get_world_size() - 1: + print(f"Step {step} loss: {loss}") optimizer.step() optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) prof.step() else: @@ -288,6 +322,7 @@ def main(): booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) prof.step() diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 2ccfb0356..f6ccb297e 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -40,6 +40,14 @@ def data_gen_for_t5_model(): return data +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen_for_encoder_only() + data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + # output transform function output_transform_fn = lambda x: x @@ -47,6 +55,7 @@ output_transform_fn = lambda x: x loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean() loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean() 
loss_fn_for_conditional_generation = lambda x: x["loss"] +loss_fn_for_token_classification = lambda x: x["loss"] # define model config config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) @@ -79,3 +88,11 @@ model_zoo.register( loss_fn=loss_fn_for_encoder_only, model_attribute=ModelAttribute(has_control_flow=True), ) +model_zoo.register( + name="transformers_t5_for_token_classification", + model_fn=lambda: transformers.T5ForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_token_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 48a8d12e0..30b557f5e 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -15,8 +15,7 @@ WORLD_SIZE = 2 def check_p2p_communication(): pg_mesh = ProcessGroupMesh(WORLD_SIZE) stage_manager = PipelineStageManager(pg_mesh, 0) - p2p = PipelineP2PCommunication(stage_manager) - + p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False) rank = dist.get_rank() tensor = torch.ones(1, device=get_accelerator().get_current_device()) @@ -31,41 +30,40 @@ def check_p2p_communication(): for obj in data: p2p.send_forward(obj) for i in range(len(data)): - recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False) + recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False) assert recv_obj == data[-(i + 1)] elif rank == 1: for obj in data: - recv_obj = p2p.recv_forward() + recv_obj, _ = p2p.recv_forward() assert recv_obj == obj for i in range(len(data)): p2p.send_backward(data[-(i + 1)]) - recv_obj = p2p.recv_forward() + recv_obj, _ = p2p.recv_forward() assert recv_obj == data[i] if rank == 1: for obj in data: p2p.send_backward(obj) for i in range(len(data)): - recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True) + recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True) assert recv_obj == data[-(i + 1)] elif rank == 0: for obj in data: - recv_obj = p2p.recv_backward() + recv_obj, _ = p2p.recv_backward() assert recv_obj == obj for i in range(len(data)): - recv_obj = p2p.recv_backward() - p2p.send_forward(data[-(i + 1)]) + recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False) assert recv_obj == data[i] if rank == 0: - recv_obj = p2p.send_forward_recv_backward( + recv_obj, _ = p2p.send_forward_recv_backward( tensor, send_metadata=False, metadata_recv=create_send_metadata(tensor), ) assert recv_obj == tensor elif rank == 1: - recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor)) + recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor)) assert recv_obj == tensor p2p.send_backward(tensor, send_metadata=False) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index 5146a86c8..a3793013b 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -52,7 +52,7 @@ def check_stage_manager(): # check p2p groups for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): if rank in [prev, cur]: - group = stage_manager.get_p2p_process_group(prev, cur) + group = stage_manager.get_p2p_process_group() dist.barrier(group=group) # check stage groups diff --git a/tests/test_shardformer/test_model/test_shard_t5.py 
b/tests/test_shardformer/test_model/test_shard_t5.py index 521dc9130..6cdf5bf41 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -41,14 +41,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, t5 = unwrap_model(org_model) sharded_t5 = unwrap_model(sharded_model) - row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] + if t5.__class__.__name__ == "T5ForTokenClassification": + row_layer_for_check = ["transformer.shared", "transformer.encoder.block[0].layer[0].SelfAttention.q"] + else: + row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: - atol, rtol = 5e-3, 5e-3 + atol, rtol = 5e-2, 5e-2 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: row_layer_grads = get_grad_tensors_for_check( t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 @@ -66,7 +69,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ != "T5ForConditionalGeneration": + if org_model.__class__.__name__ not in ["T5ForConditionalGeneration", "T5ForTokenClassification"]: check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -157,7 +160,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) @clear_cache_before_run() def run_t5_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") + sub_model_zoo = model_zoo.get_sub_registry(["transformers_t5_for_token_classification"]) for name, ( model_fn, @@ -167,7 +170,10 @@ def run_t5_test(test_config): _, ) in sub_model_zoo.items(): # skip 4-stage pp test for t5_encoder - if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": + if test_config["pp_size"] > 2 and name in [ + "transformers_t5_encoder_model", + "transformers_t5_for_token_classification", + ]: continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/version.txt b/version.txt index 940ac09aa..1d0ba9ea1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.9 +0.4.0
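The benchmark changes expose the interleaved schedule and p2p overlap through plugin arguments (`pp_style`, `num_model_chunks` via `--n_chunks`, `overlap_p2p` via `--overlap`). A hedged sketch of constructing the plugin with those knobs, assuming a 4-rank `torchrun` launch with 2-way tensor and 2-way pipeline parallelism (the exact sizes are illustrative):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # expects torchrun with tp_size * pp_size ranks

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="interleaved",      # matches --pp_style interleaved
    num_model_chunks=2,          # matches --n_chunks 2
    zero_stage=1,
    microbatch_size=1,
    precision="bf16",
    overlap_p2p=True,            # matches --overlap
    enable_metadata_cache=True,  # the benchmark disables this with --no_cache
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
```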