Mirror of https://github.com/hpcaitech/ColossalAI
Commit e7527762a1: Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into rlhf_SimPO

README.md

@ -9,6 +9,7 @@
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> |
<a href="https://hpc-ai.com/blog"> Blog </a></h3>

[](https://github.com/hpcaitech/ColossalAI/stargazers)
@ -132,6 +133,8 @@ distributed training and inference in a few lines.
|
|||
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
|
||||
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
|
||||
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
|
||||
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
|
||||
[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)
|
||||
|
||||
<div align="center">
|
||||
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
|
||||
|
@ -143,6 +146,9 @@ distributed training and inference in a few lines.
|
|||
|
||||
### Colossal-LLaMA-2
|
||||
|
||||
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
|
||||
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
|
||||
|
||||
- 7B: Half a day of training for a few hundred dollars yields results similar to mainstream large models: an open-source, commercial-free, domain-specific LLM solution.
|
||||
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
|
||||
[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
|
||||
|
@ -275,6 +281,8 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
|
|||
|
||||
- 70 billion parameter LLaMA3 model training accelerated by 18%
|
||||
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
|
||||
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
|
||||
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
|
||||
|
||||
### LLaMA2
|
||||
<p align="center">
|
||||
|
@ -385,6 +393,8 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
|
|||
- Large AI model inference speed doubled compared with vLLM's offline inference performance in some cases.
|
||||
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference)
|
||||
[[blog]](https://hpc-ai.com/blog/colossal-inference)
|
||||
[[GPU Cloud Playground]](https://cloud.luchentech.com/)
|
||||
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)
|
||||
|
||||
### Grok-1
|
||||
<p id="Grok-1" align="center">
|
||||
@ -2,6 +2,12 @@
|
|||
<h1>
|
||||
Colossal-LLaMA
|
||||
</h1>
|
||||
|
||||
<h3>
|
||||
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
|
||||
<a href="https://cloud.luchentech.com/doc/docs/image/llama"> LLaMA3 Image </a>
|
||||
</h3>
|
||||
|
||||
</div>
|
||||
|
||||
## Table of Contents
|
||||
@ -2,6 +2,12 @@
|
|||
<h1>
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/colossal-llama-2/colossaleval.jpg?raw=true" width=800/>
|
||||
</h1>
|
||||
|
||||
<h3>
|
||||
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
|
||||
<a href="https://cloud.luchentech.com/doc/docs/image/colossal-eval"> Colossal-Eval Image </a>
|
||||
</h3>
|
||||
|
||||
</div>
|
||||
|
||||
## Table of Contents
|
||||
@ -2,6 +2,15 @@
|
|||
|
||||
This directory contains the applications that are powered by Colossal-AI.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<h3>
|
||||
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> </a> |
|
||||
<a href="https://cloud.luchentech.com/doc/docs/intro"> Playground Document </a>
|
||||
</h3>
|
||||
|
||||
</div>
|
||||
|
||||
The list of applications include:
|
||||
|
||||
- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
|
||||
@ -369,9 +369,9 @@ class GeminiPlugin(DPPluginBase):
|
|||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
if get_accelerator().name == "npu":
|
||||
assert placement_policy == "static", "NPU only supports static placement policy"
|
||||
if placement_policy == "auto" and enable_async_reduce:
|
||||
if enable_async_reduce and not pin_memory:
|
||||
logging.warning(
|
||||
f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
|
||||
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
|
||||
)
|
||||
pin_memory = True
|
||||
self.gemini_config = dict(
|
||||
|
|
|
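The pin_memory auto-enable above is worth a note: asynchronous gradient reduction with CPU offload relies on page-locked host buffers so that device-to-host copies can run without blocking computation. A minimal sketch of that interplay (a hypothetical helper, not part of the plugin API):

```python
import torch

def offload_grad_async(grad: torch.Tensor, pin_memory: bool = True) -> torch.Tensor:
    # Page-locked (pinned) host memory lets the D2H copy run on the copy engine
    # concurrently with compute; with pageable memory the copy is effectively synchronous.
    cpu_buf = torch.empty(grad.shape, dtype=grad.dtype, device="cpu", pin_memory=pin_memory)
    cpu_buf.copy_(grad, non_blocking=pin_memory)
    return cpu_buf
```

This is why the plugin warns and flips pin_memory to True rather than silently running enable_async_reduce without it.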
@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): used when padding the vocabulary size so that a faster kernel can be chosen. Defaults to 64.
|
||||
|
||||
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
enable_metadata_cache: bool = True,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
|
@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
assert (
|
||||
self.zero_stage <= 1
|
||||
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
|
@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
|
|
|
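For context on the assertions above: once pp_size > 1, either num_microbatches or microbatch_size must be given, and the ZeRO stage must stay at 0 or 1. A hedged configuration sketch (all sizes are illustrative placeholders):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Example values only; pick tp/pp sizes that match your cluster.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    num_microbatches=4,        # required (or microbatch_size) when pp_size > 1
    zero_stage=1,              # must be 0 or 1 together with pipeline parallelism
    enable_metadata_cache=True,
    overlap_p2p=True,          # new flag introduced by this change
)
booster = Booster(plugin=plugin)
```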
@ -134,7 +134,7 @@ class ProcessGroupMesh:
|
|||
"""
|
||||
|
||||
assert mode in ["raise", "wrap", "clip"]
|
||||
return np.ravel_multi_index(coord, shape, mode)
|
||||
return int(np.ravel_multi_index(coord, shape, mode))
|
||||
|
||||
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
|
||||
|
@ -182,7 +182,7 @@ class ProcessGroupMesh:
|
|||
axis = [
|
||||
axis,
|
||||
]
|
||||
assert isinstance(indices_at_axis[0], int)
|
||||
assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
|
||||
indices_at_axis = [
|
||||
indices_at_axis,
|
||||
]
|
||||
|
|
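The int(...) cast added above matters because np.ravel_multi_index returns a NumPy integer, and some torch.distributed calls expect a plain Python int where a rank is passed around. A quick illustration:

```python
import numpy as np

coord, shape = (1, 2), (2, 4)
rank = np.ravel_multi_index(coord, shape, mode="raise")
print(type(rank))       # a NumPy integer (numpy.intp), not a Python int
print(type(int(rank)))  # <class 'int'> -- what the method now returns
```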
|
@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
|
|||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
@ -122,16 +123,24 @@ class InferenceEngine:
|
|||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
# NOTE(lry89757) Currently we load the model using transformers-api,
|
||||
# but we will use lazy tensor and checkpoint io to accelerate
|
||||
# the model load process in the future.
|
||||
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
|
||||
if arch is "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
@ -189,14 +198,13 @@ class InferenceEngine:
|
|||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
# NOTE(lry89757) Deprecated for now; will be reused once lazy tensor is introduced
|
||||
# if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
|
||||
# from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
# cpt_io = InferCheckpoint_io()
|
||||
# if_has_index_file, model_index_file = has_index_file(model_or_path)
|
||||
# assert if_has_index_file, "the model path is invalid"
|
||||
# cpt_io.load_model(self.model, model_index_file)
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
|
|
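Taken together, the changes above switch the engine to building the model under LazyInitContext and then streaming real weights in from the sharded checkpoint via its index file. A hedged sketch of that flow (the checkpoint path is hypothetical; the checkpoint-io step mirrors the code in the hunk):

```python
from transformers import AutoModelForCausalLM
from colossalai.lazy import LazyInitContext
from colossalai.inference.utils import has_index_file

model_or_path = "/checkpoints/llama3-70b"  # hypothetical local sharded checkpoint

# 1) Build the module structure lazily; weights are not materialized yet.
with LazyInitContext(default_device="cuda"):
    model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True)

# 2) After sharding, load the real weights through the *.index.json weight map.
if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid"
# from colossalai.inference.core.plugin import InferCheckpoint_io
# InferCheckpoint_io().load_model(sharded_model, model_index_file)
```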
|
@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):
|
|||
|
||||
try:
|
||||
if isinstance(model_or_path, str):
|
||||
self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
||||
self.model_config = AutoConfig.from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
elif isinstance(model_or_path, nn.Module):
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
|
||||
|
|
|
@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
|
|||
model_policy_map,
|
||||
)
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.utils import get_model_size
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
|
|||
model_policy (Policy): the policy to replace the model
|
||||
"""
|
||||
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
# is_local = os.path.isdir(model_or_path)
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
# NOTE(lry89757) Currently we load the model using transformers-api,
|
||||
# but we will use lazy tensor and checkpoint io to accelerate
|
||||
# the model load process in the future.
|
||||
model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
|
||||
# if is_local:
|
||||
# model = _SUPPORTED_MODELS[arch](hf_config)
|
||||
# else:
|
||||
# # load the real checkpoint
|
||||
# model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
|
||||
if arch is "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _SUPPORTED_MODELS[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
|
@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
|
|||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
# NOTE(lry89757) Deprecated for now; will be reused once lazy tensor is introduced
|
||||
# if isinstance(model_or_path, str) and is_local:
|
||||
# from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
# cpt_io = InferCheckpoint_io()
|
||||
# if_has_index_file, model_index_file = has_index_file(model_or_path)
|
||||
# assert if_has_index_file, "the model path is invalid"
|
||||
# cpt_io.load_model(self.model, model_index_file)
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from typing import List, Union
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import Linear1D_Col
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
|
||||
|
@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
|
|||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
LazyInitContext.materialize(module)
|
||||
module.in_features = module.weight.size(1)
|
||||
module.out_features = module.weight.size(0)
|
||||
module.bias = None
|
||||
module.weight.data = nn.functional.normalize(
|
||||
module.weight
|
||||
) # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
|
||||
) # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
|
||||
# So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
|
||||
|
||||
return Linear1D_Col.from_native_module(
|
||||
module,
|
||||
process_group,
|
||||
*args,
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
if out_features < tp_size:
|
||||
return module
|
||||
|
||||
if out_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
|
||||
lmhead_1d = BaichuanLMHeadLinear1D_Col(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lmhead_1d
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
|
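Normalizing the weight inside _load_from_state_dict above appears to bake Baichuan's NormHead row normalization into the loaded lm_head weight once, instead of recomputing it on every forward. A toy check of what nn.functional.normalize does here (shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

w = torch.randn(8, 4)       # toy lm_head weight, shape [vocab_size, hidden_size]
w_norm = F.normalize(w)     # row-wise L2 normalization (p=2, dim=1 by default)
print(torch.allclose(w_norm.norm(dim=1), torch.ones(8)))  # True: every row has unit norm
```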
|
@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
|
||||
"""
|
||||
ParallelModule.__init__(self)
|
||||
self.o_proj = attn_oproj
|
||||
|
||||
self.config = config
|
||||
self.num_heads = num_heads
|
||||
|
@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
|
|||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.process_group = process_group
|
||||
self.W_pack = W_pack
|
||||
self.o_proj = attn_oproj
|
||||
self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
|
||||
self.attention_backend = get_attention_backend(model_shard_infer_config)
|
||||
self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
|
||||
|
|
|
@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
|||
self.gate_up_weight = nn.Parameter(
|
||||
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
|
||||
)
|
||||
self.gate_up_dict = {
|
||||
"gate_proj.weight": None,
|
||||
"up_proj.weight": None,
|
||||
} # used and delattr in load/shard of gate/up weight
|
||||
self.down_proj = mlp_dproj
|
||||
self.process_group = process_group
|
||||
|
||||
|
@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
|||
):
|
||||
# NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
|
||||
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
if hasattr(self, "gate_up_dict"):
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "gate_up_weight"
|
||||
k1 = "gate_proj.weight"
|
||||
k2 = "up_proj.weight"
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
for weight_name in self.gate_up_dict:
|
||||
prefix_weight_name = prefix + weight_name
|
||||
if prefix_weight_name in state_dict.keys():
|
||||
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
|
||||
self.gate_up_dict[weight_name] = w.T
|
||||
|
||||
gate_w = state_dict[prefix + k1]
|
||||
up_w = state_dict[prefix + k2]
|
||||
if None not in self.gate_up_dict.values():
|
||||
# we've got all the weights of gate/up
|
||||
gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
|
||||
up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
|
||||
input_param = nn.Parameter(
|
||||
gate_up_w
|
||||
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
|
||||
key = "gate_up_weight"
|
||||
param = local_state.get(key, None)
|
||||
|
||||
input_param = nn.Parameter(
|
||||
gate_up_w
|
||||
) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
param = local_state[key]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
del self.gate_up_dict
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
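The gate_up_dict introduced above (and the qkv_dict below) drops the assumption that both gate and up projections arrive in the same call: with sharded checkpoints such as LLaMA3-70B the tensors may live in different shard files, so each _load_from_state_dict call caches what it received and only stacks once everything is present. A stripped-down sketch of that accumulate-then-stack pattern (a standalone helper, not the module method itself):

```python
import torch

qkv_dict = {"q_proj.weight": None, "k_proj.weight": None, "v_proj.weight": None}

def maybe_stack(state_dict: dict, prefix: str = ""):
    # Cache whichever projection weights this shard file provides.
    for name in qkv_dict:
        if prefix + name in state_dict:
            qkv_dict[name] = state_dict[prefix + name].T  # transposed, as the module stores them
    # Only fuse once q, k and v have all been collected across shard files.
    if None not in qkv_dict.values():
        return torch.stack(list(qkv_dict.values()), dim=0)
    return None  # wait for the next shard file
```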
@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
self.helper_layout = (
|
||||
attn_qproj_w.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
self.qkv_dict = {
|
||||
"q_proj.weight": None,
|
||||
"k_proj.weight": None,
|
||||
"v_proj.weight": None,
|
||||
} # used and delattr in load/shard of qkv weight
|
||||
else:
|
||||
self.helper_layout = (
|
||||
attn_qproj_w.dist_layout
|
||||
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
|
||||
self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
|
||||
self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
|
||||
self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
|
||||
|
@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
|
|||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
if self.num_heads == self.num_key_value_heads:
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
|
||||
if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
|
||||
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
key = "qkv_weight"
|
||||
k1 = "q_proj.weight"
|
||||
k2 = "k_proj.weight"
|
||||
k3 = "v_proj.weight"
|
||||
q_w = state_dict[prefix + k1]
|
||||
k_w = state_dict[prefix + k2]
|
||||
v_w = state_dict[prefix + k3]
|
||||
|
||||
device_mesh = self.helper_layout.device_mesh
|
||||
sharding_spec = self.helper_layout.sharding_spec
|
||||
q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
|
||||
k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
|
||||
v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
|
||||
# NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
|
||||
# Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
|
||||
# Unfortunately, the q, k, v weights are often not all in the same sharded checkpoint file (e.g. meta-llama/llama3-70B),
# so we only stack them once all three weights have actually been collected.
|
||||
for weight_name in self.qkv_dict:
|
||||
prefix_weight_name = prefix + weight_name
|
||||
if prefix_weight_name in state_dict.keys():
|
||||
w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
|
||||
self.qkv_dict[weight_name] = w.T
|
||||
|
||||
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
|
||||
if None not in self.qkv_dict.values():
|
||||
# we've got all the weights of q, k, v
|
||||
qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
|
||||
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
input_param = nn.Parameter(
|
||||
qkv_w
|
||||
) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
|
||||
|
||||
param = local_state[key]
|
||||
param = local_state[key]
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
del self.qkv_dict
|
||||
|
||||
else:
|
||||
|
||||
def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
|
||||
if prefix + origin_weight_name in state_dict.keys():
|
||||
attn_qproj_w = state_dict[prefix + origin_weight_name]
|
||||
w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
|
||||
input_param = nn.Parameter(w.T)
|
||||
param = local_state[local_weight_name]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
key = local_weight_name
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
|
||||
if prefix + "q_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
|
||||
|
||||
if prefix + "k_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
|
||||
|
||||
if prefix + "v_proj.weight" in state_dict.keys():
|
||||
_load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
|
||||
|
||||
strict = False # to avoid unexpected_keys
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
|
|
@ -225,31 +225,41 @@ def _batch_send_recv_tensor(
|
|||
send_group: Optional[ProcessGroup],
|
||||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
overlap_p2p: bool = True,
|
||||
send_first: bool = True,
|
||||
) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
buffer_recv = None
|
||||
if recv_tensor_metadata is not None:
|
||||
buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)
|
||||
|
||||
ops = []
|
||||
if send_dst is not None and send_tensor_list is not None:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
if recv_src is not None and buffer_recv is not None:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
is_send = send_dst is not None and send_tensor_list is not None
|
||||
is_recv = recv_src is not None and buffer_recv is not None
|
||||
|
||||
if send_first:
|
||||
if is_send:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
if is_recv:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if is_recv:
|
||||
assert recv_group is not None
|
||||
_filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
|
||||
if is_send:
|
||||
assert send_group is not None
|
||||
_filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# Remove synchronization according to Pytorch's documentation
|
||||
# However, the Megatron-LM does synchronization here
|
||||
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
|
||||
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
return buffer_recv
|
||||
if not overlap_p2p:
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
return buffer_recv, []
|
||||
else:
|
||||
return buffer_recv, reqs
|
||||
return None, []
|
||||
|
||||
|
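The send_first branching above lets two adjacent stages that both send and receive enqueue their batched P2P operations in opposite orders, so the isend/irecv pairs line up across ranks instead of mismatching or deadlocking. A minimal sketch of the ordering idea (an illustrative helper, not the library function):

```python
import torch.distributed as dist

def build_p2p_ops(send_first: bool, send_tensor, recv_buf, peer: int, group):
    # One side of the pair is called with send_first=True, the other with send_first=False,
    # so that the batched isend/irecv operations pair up deterministically.
    pair = [
        dist.P2POp(dist.isend, send_tensor, peer, group=group),
        dist.P2POp(dist.irecv, recv_buf, peer, group=group),
    ]
    # Later: reqs = dist.batch_isend_irecv(ops); wait on reqs only when not overlapping.
    return pair if send_first else list(reversed(pair))
```

Returning the requests instead of waiting on them (the overlap_p2p branch) is what allows the schedules below to overlap communication with computation.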
||||
def _send_recv_serialization_object(
|
||||
|
@ -260,10 +270,11 @@ def _send_recv_serialization_object(
|
|||
recv_group: Optional[ProcessGroup],
|
||||
current_device: Any,
|
||||
is_nccl_backend: bool,
|
||||
send_first: bool = True,
|
||||
) -> Optional[P2PMetadata]:
|
||||
ops = []
|
||||
|
||||
send_object_tensor = None
|
||||
send_object_size_tensor = None
|
||||
if object is not None and send_dst is not None:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
|
||||
|
@ -274,43 +285,54 @@ def _send_recv_serialization_object(
|
|||
send_object_size_tensor = send_object_size_tensor.to(current_device)
|
||||
send_object_tensor = send_object_tensor.to(current_device)
|
||||
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
recv_object_size_tensor = None
|
||||
if recv_src is not None:
|
||||
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
|
||||
if is_nccl_backend:
|
||||
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if send_first:
|
||||
if send_object_size_tensor is not None:
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
if recv_src is not None:
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if recv_src is not None:
|
||||
_filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
if send_object_size_tensor is not None:
|
||||
_filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
req.wait() # This blocks the compute stream in torch
|
||||
|
||||
ops = []
|
||||
|
||||
if send_dst is not None and send_object_tensor is not None:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
is_send = send_dst is not None and send_object_tensor is not None
|
||||
is_recv = recv_src is not None and recv_object_size_tensor is not None
|
||||
|
||||
recv_object_tensor = None
|
||||
if recv_src is not None and recv_object_size_tensor is not None:
|
||||
if is_recv:
|
||||
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
|
||||
if is_nccl_backend:
|
||||
recv_object_tensor = recv_object_tensor.to(current_device)
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
|
||||
if send_first:
|
||||
if is_send:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
if is_recv:
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
else:
|
||||
if is_recv:
|
||||
_filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
|
||||
if is_send:
|
||||
_filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# See the comment in `_batch_send_recv_tensor`
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
if recv_object_tensor is not None and recv_object_size_tensor is not None:
|
||||
recv_object_tensor = recv_object_tensor.type(torch.uint8)
|
||||
if recv_object_tensor.device != torch.device("cpu"):
|
||||
|
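For reference, the metadata path above first exchanges a size tensor and then a uint8 payload produced by c10d's object-to-tensor helper. A self-contained approximation of that conversion (not the private torch helper itself):

```python
import io
import pickle

import torch

def object_to_tensor(obj):
    # Serialize the object, expose it as a uint8 tensor, and record its length.
    buf = io.BytesIO()
    pickle.dump(obj, buf)
    data = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)
    size = torch.tensor([data.numel()], dtype=torch.long)
    # The receiver first gets `size`, allocates an empty uint8 buffer of that length,
    # then receives `data` and unpickles it back into the original object.
    return data, size
```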
@ -328,11 +350,12 @@ def _communicate(
|
|||
object: Any,
|
||||
send_dst: Optional[int],
|
||||
recv_src: Optional[int],
|
||||
overlap_p2p: bool,
|
||||
send_group: Optional[ProcessGroup] = None,
|
||||
recv_group: Optional[ProcessGroup] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send and receive object from send_dst and recv_src respectively
|
||||
|
@ -341,6 +364,7 @@ def _communicate(
|
|||
object (Any): object needed to be sent
|
||||
send_dst (int): rank of the destination
|
||||
recv_src (int): rank of the source
|
||||
overlap_p2p (bool): whether to overlap p2p communication with computation
|
||||
send_group (ProcessGroup, optional): process group of sender
|
||||
recv_group (ProcessGroup, optional): process group of receiver
|
||||
send_metadata (bool, optional): whether to send metadata
|
||||
|
@ -358,32 +382,10 @@ def _communicate(
|
|||
# NOTE: if object contains non-tensor objects, we have to send metadata
|
||||
metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
|
||||
send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
|
||||
else:
|
||||
send_metadata = False
|
||||
|
||||
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
|
||||
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
|
||||
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
|
||||
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
|
||||
if send_prior_fallback:
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
else:
|
||||
recv_data = _communicate(
|
||||
None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv
|
||||
)
|
||||
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
|
||||
return recv_data
|
||||
|
||||
# NOTE: only the following 5 cases are valid:
|
||||
# 1. send() [needs extra metadata] and no recv()
|
||||
# 2. recv() [needs extra metadata] and no send()
|
||||
# 3. neither send() nor recv() need extra metadata
|
||||
assert not (send_dst is not None and send_metadata) or recv_src is None
|
||||
assert not (recv_src is not None and metadata_recv is None) or send_dst is None
|
||||
assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None)
|
||||
assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
|
||||
|
||||
current_send_device, is_send_nccl_backend = _check_device(send_group)
|
||||
current_recv_device, is_recv_nccl_backend = _check_device(recv_group)
|
||||
|
||||
|
@ -402,14 +404,25 @@ def _communicate(
|
|||
recv_group=recv_group if metadata_recv is None else None,
|
||||
current_device=current_device,
|
||||
is_nccl_backend=is_nccl_backend,
|
||||
send_first=send_first if send_first is not None else True,
|
||||
)
|
||||
assert metadata_recv is None or _metadata_recv is None
|
||||
assert (
|
||||
metadata_recv is None or _metadata_recv is None
|
||||
), "You shouldn't receive metadata when using the cached metadata"
|
||||
metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv
|
||||
|
||||
# Send and receive data
|
||||
recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
|
||||
recv_tensor_objs = _batch_send_recv_tensor(
|
||||
tensor_objs, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device
|
||||
recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
|
||||
tensor_objs,
|
||||
recv_tensor_metadata,
|
||||
send_dst,
|
||||
recv_src,
|
||||
send_group,
|
||||
recv_group,
|
||||
current_device,
|
||||
overlap_p2p=overlap_p2p,
|
||||
send_first=send_first if send_first is not None else True,
|
||||
)
|
||||
|
||||
if metadata_recv is not None:
|
||||
|
@ -424,33 +437,9 @@ def _communicate(
|
|||
for idx in non_tensor_obj_idx:
|
||||
recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
|
||||
recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
|
||||
return recv_object, wait_handles
|
||||
|
||||
return recv_object
|
||||
|
||||
|
||||
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
|
||||
"""send anything to dst rank
|
||||
|
||||
Args:
|
||||
object (Any): object needed to be sent
|
||||
dst (int): rank of the destination
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)
|
||||
|
||||
|
||||
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
|
||||
"""recv anything from src
|
||||
|
||||
Args:
|
||||
src (int): source rank of data. local rank will receive data from src rank.
|
||||
|
||||
Returns:
|
||||
Any: Object received from src.
|
||||
"""
|
||||
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)
|
||||
return None, wait_handles
|
||||
|
||||
|
||||
def _p2p_comm(
|
||||
|
@ -532,10 +521,13 @@ def _p2p_comm(
|
|||
|
||||
|
||||
class PipelineP2PCommunication:
|
||||
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
||||
def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
|
||||
self.stage_manager = stage_manager
|
||||
self.overlap_p2p = overlap_p2p
|
||||
|
||||
def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
def recv_forward(
|
||||
self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
|
||||
) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
||||
Args:
|
||||
|
@ -543,95 +535,186 @@ class PipelineP2PCommunication:
|
|||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
input_tensor = _recv_object(
|
||||
prev_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
|
||||
input_tensor, wait_handles = _communicate(
|
||||
object=None,
|
||||
recv_src=prev_rank,
|
||||
send_dst=None,
|
||||
recv_group=self.stage_manager.get_p2p_process_group(),
|
||||
metadata_recv=metadata_recv,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any:
|
||||
def recv_backward(
|
||||
self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
|
||||
) -> Tuple[Any, List]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
||||
Args:
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
output_tensor_grad = _recv_object(
|
||||
next_rank,
|
||||
cur_rank,
|
||||
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
|
||||
|
||||
output_tensor_grad, wait_handles = _communicate(
|
||||
object=None,
|
||||
recv_src=next_rank,
|
||||
send_dst=None,
|
||||
recv_group=self.stage_manager.get_p2p_process_group(),
|
||||
metadata_recv=metadata_recv,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(
|
||||
_, handles = _communicate(
|
||||
output_object,
|
||||
cur_rank,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
recv_src=None,
|
||||
send_dst=next_rank,
|
||||
send_group=self.stage_manager.get_p2p_process_group(),
|
||||
send_metadata=send_metadata,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
return handles
|
||||
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
|
||||
def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(
|
||||
_, handles = _communicate(
|
||||
input_object,
|
||||
cur_rank,
|
||||
prev_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
|
||||
recv_src=None,
|
||||
send_dst=prev_rank,
|
||||
send_group=self.stage_manager.get_p2p_process_group(),
|
||||
send_metadata=send_metadata,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
return handles
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self,
|
||||
output_object: Any,
|
||||
is_send: bool,
|
||||
is_recv: bool,
|
||||
send_first: bool,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the input tensor to the next pipeline stage and copy the output tensor from the next pipeline stage
|
||||
|
||||
Args:
|
||||
output_object (Any): Object to be sent.
|
||||
is_send (bool): Whether to send the input tensor to the next pipeline stage.
|
||||
is_recv (bool): Whether to copy the output tensor from the next pipeline stage.
|
||||
send_first (bool): Whether to send before receive.
|
||||
send_metadata (bool, optional): Whether to send metadata.
|
||||
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
|
||||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
next_rank = self.stage_manager.get_next_rank() if is_send else None
|
||||
prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
output_object,
|
||||
send_dst=next_rank,
|
||||
recv_src=prev_rank,
|
||||
send_group=group if is_send else None,
|
||||
recv_group=group if is_recv else None,
|
||||
send_metadata=send_metadata if is_send else False,
|
||||
metadata_recv=metadata_recv if is_recv else None,
|
||||
send_first=send_first,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self,
|
||||
input_object: Any,
|
||||
is_send: bool,
|
||||
is_recv: bool,
|
||||
send_first: bool,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to the previous pipeline stage and copy the gradient tensor from the previous pipeline stage
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
|
||||
is_recv (bool): Whether to copy the gradient tensor from the previous pipeline stage.
|
||||
send_first (bool): Whether to send before receive.
|
||||
send_metadata (bool, optional): Whether to send metadata.
|
||||
metadata_recv (P2PMetadata, optional): The cached metadata(size, type) of the object to be received.
|
||||
|
||||
Returns:
|
||||
Any: The gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
prev_rank = self.stage_manager.get_prev_rank() if is_send else None
|
||||
next_rank = self.stage_manager.get_next_rank() if is_recv else None
|
||||
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
|
||||
return _communicate(
|
||||
input_object,
|
||||
send_dst=prev_rank,
|
||||
recv_src=next_rank,
|
||||
send_group=group if is_send else None,
|
||||
recv_group=group if is_recv else None,
|
||||
send_metadata=send_metadata if is_send else False,
|
||||
metadata_recv=metadata_recv if is_recv else None,
|
||||
send_first=send_first,
|
||||
overlap_p2p=self.overlap_p2p,
|
||||
)
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
input_object: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the next pipeline stage
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the sender and recipient of the tensor
|
||||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
|
||||
Returns:
|
||||
Any: The gradient tensor or gradient tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
input_object,
|
||||
next_rank,
|
||||
|
@ -640,28 +723,28 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
input_object: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_metadata: bool = True,
|
||||
metadata_recv: Optional[P2PMetadata] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
send_first: Optional[bool] = None,
|
||||
) -> Tuple[Any, List]:
|
||||
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
|
||||
|
||||
Args:
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the sender and recipient of the tensor
|
||||
"""
|
||||
if prev_rank is None:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
List: List of handles for the communication requests, if overlap is enabled.
|
||||
"""
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
group = self.stage_manager.get_p2p_process_group()
|
||||
return _communicate(
|
||||
input_object,
|
||||
prev_rank,
|
||||
|
@ -670,7 +753,8 @@ class PipelineP2PCommunication:
|
|||
recv_group=group,
|
||||
send_metadata=send_metadata,
|
||||
metadata_recv=metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
send_first=send_first,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
def p2p_communicate(
|
||||
|
@ -679,7 +763,7 @@ class PipelineP2PCommunication:
|
|||
recv_pre: bool,
|
||||
next_rank: Optional[int] = None,
|
||||
comm_dtype: torch.dtype = torch.float16,
|
||||
) -> None:
|
||||
) -> Any:
|
||||
"""
|
||||
Sends the input tensor to the next stage in the pipeline, using `P2POp` in torch.
|
||||
|
||||
|
@ -689,12 +773,11 @@ class PipelineP2PCommunication:
|
|||
"""
|
||||
if next_rank is None:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
recv_tensor = _p2p_comm(
|
||||
output_object,
|
||||
recv_pre,
|
||||
next_rank,
|
||||
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
|
||||
self.stage_manager.get_p2p_process_group(),
|
||||
comm_dtype,
|
||||
)
|
||||
return recv_tensor
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@ -16,6 +17,12 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
|
|||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
||||
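_wait_p2p above is the synchronization point for the wait handles that every send/recv now returns. A hedged, self-contained sketch of the intended overlap pattern (the recv stand-in is hypothetical):

```python
from typing import Any, List, Tuple

def recv_forward_async() -> Tuple[Any, List]:
    # Stand-in for PipelineP2PCommunication.recv_forward with overlap_p2p=True:
    # it returns (buffer, wait_handles) without blocking on the transfer.
    return "input_tensor", []

def _wait_p2p(wait_handles: List) -> None:
    if wait_handles is not None:
        for req in wait_handles:
            req.wait()

input_tensor, handles = recv_forward_async()
# ... run computation that does not depend on input_tensor ...
_wait_p2p(handles)  # block only right before input_tensor is actually consumed
```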
|
||||
class InterleavedSchedule(PipelineSchedule):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -24,13 +31,15 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
num_microbatch: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
num_microbatch is not None or microbatch_size is not None
|
||||
), "Either num_microbatch or microbatch_size should be provided"
|
||||
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
self.overlap_p2p = overlap_p2p
|
||||
self.num_microbatch = num_microbatch
|
||||
self.microbatch_size = microbatch_size
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
@ -113,14 +122,17 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
int: The model chunk idx of the input microbatch_id
|
||||
"""
|
||||
assert microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
assert (
|
||||
microbatch_id < self.num_microbatch * self.num_model_chunks
|
||||
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
|
||||
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not is_forward:
|
||||
# Reverse order
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
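The tightened assertion above guards get_model_chunk_id's mapping from a global microbatch index to a model chunk. A worked example under assumed sizes (4 pipeline stages, 2 model chunks) shows the cyclic forward order and its reversal for the backward pass:

```python
# Assumed sizes for illustration only.
num_stages, num_model_chunks, num_microbatch = 4, 2, 8

def get_model_chunk_id(microbatch_id: int, is_forward: bool) -> int:
    assert microbatch_id < num_microbatch * num_model_chunks
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not is_forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1  # reverse order for backward
    return model_chunk_id

print([get_model_chunk_id(i, True) for i in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
print([get_model_chunk_id(i, False) for i in range(8)])  # [1, 1, 1, 1, 0, 0, 0, 0]
```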
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -130,16 +142,19 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
Returns:
|
||||
Any: The input tensor or input tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
input_tensor, wait_handles = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
|
||||
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
return None, []
|
||||
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
|
||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -149,16 +164,20 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
Returns:
|
||||
Any: The input gradient tensor or gradient tensor list.
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank, metadata_recv=self.grad_metadata_recv
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
return output_tensor_grad
|
||||
return None, []
|
||||
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None:
|
||||
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -166,13 +185,18 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
model_chunk_id (int): The current model chunk idx.
|
||||
output_object (Any): Object to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For interleaved 1F1B.
|
||||
|
||||
|
@ -180,99 +204,61 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
model_chunk_id (int): The current model chunk idx.
|
||||
input_object (Any): Object to be sent.
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
|
||||
Returns:
|
||||
Any: The wait handles for the communication.
|
||||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
|
||||
def send_forward_recv_backward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
output_tensor: Any,
|
||||
next_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_last_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_forward_recv_backward and self.grad_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
output_tensor_grad = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
next_rank,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad
|
||||
|
||||
# send only or recv only
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
return self.recv_backward(model_chunk_id_recv)
|
||||
|
||||
def send_backward_recv_forward(
|
||||
self,
|
||||
model_chunk_id_send: int,
|
||||
model_chunk_id_recv: int,
|
||||
input_tensor_grad: Any,
|
||||
prev_rank: Optional[int] = None,
|
||||
send_prior_fallback: Optional[bool] = None,
|
||||
) -> Any:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
send_data = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
recv_data = not self.stage_manager.is_first_stage()
|
||||
|
||||
if send_data and recv_data:
|
||||
if not self.send_backward_recv_backward and self.tensor_metadata_recv is not None:
|
||||
send_prior_fallback = None # must not fallback
|
||||
input_tensor = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
prev_rank,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_prior_fallback=send_prior_fallback,
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
# send only or recv only
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
return self.recv_forward(model_chunk_id_recv)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
def send_forward_recv_forward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
else:
|
||||
input_tensor = self.recv_forward(model_chunk_id_recv)
|
||||
self.send_forward(model_chunk_id_send, output_tensor)
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_first: bool = True
|
||||
) -> Tuple[Any, List]:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
is_send = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_first_stage()
|
||||
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
|
||||
output_tensor,
|
||||
is_send,
|
||||
is_recv,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
metadata_recv=self.tensor_metadata_recv,
|
||||
send_first=send_first,
|
||||
)
|
||||
# Cache metadata
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
return input_tensor
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def send_backward_recv_backward(
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool
|
||||
):
|
||||
if send_prior:
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward(model_chunk_id_recv)
|
||||
self.send_backward(model_chunk_id_send, input_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_first: bool = True
|
||||
) -> Tuple[Any, List]:
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_send):
|
||||
is_send = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_last_stage()
|
||||
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
is_send,
|
||||
is_recv,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
metadata_recv=self.grad_metadata_recv,
|
||||
send_first=send_first,
|
||||
)
|
||||
# Cache metadata
|
||||
self.send_grad_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
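The reworked `send_forward_recv_forward` / `send_backward_recv_backward` above fuse one send and one receive into a single call on the communication layer and return the wait handles for both, with `send_first` controlling the posting order. A common way to implement such fused, non-blocking pairs in PyTorch is `dist.P2POp` plus `dist.batch_isend_irecv`; the sketch below is a generic illustration of that pattern under assumed ranks and buffers, not the actual `PipelineP2PCommunication` implementation.

```python
import torch
import torch.distributed as dist


def send_recv_fused(send_buf: torch.Tensor, recv_buf: torch.Tensor,
                    send_rank: int, recv_rank: int, send_first: bool = True):
    """Post a send and a receive together and return their wait handles."""
    send_op = dist.P2POp(dist.isend, send_buf, send_rank)
    recv_op = dist.P2POp(dist.irecv, recv_buf, recv_rank)
    # Alternating which side posts its send first (as the schedule does with
    # stage parity) helps avoid both peers blocking on their sends.
    ops = [send_op, recv_op] if send_first else [recv_op, send_op]
    wait_handles = dist.batch_isend_irecv(ops)
    return recv_buf, wait_handles
```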
def forward_step(
self,

@@ -294,10 +280,12 @@ class InterleavedSchedule(PipelineSchedule):
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)

# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):

@@ -381,23 +369,27 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())

fwd_wait_handles = []
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)

for i in range(self.num_microbatch * self.num_model_chunks):
last_iteration = i == self.num_microbatch * self.num_model_chunks - 1
last_batch = i == self.num_microbatch * self.num_model_chunks - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)

# Wait until current input is received
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

if not last_iteration:
input_obj = self.send_forward_recv_forward(
if not last_batch:
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)
else:
self.send_forward(model_chunk_id, output_obj)
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)

if outputs is not None:
outputs = merge_batch(outputs)

@@ -420,7 +412,9 @@ class InterleavedSchedule(PipelineSchedule):
self.load_batch(data_iter)

num_microbatch = self.num_microbatch * self.num_model_chunks
# Forward + until 1st backward
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
# Steps needed to reach the last chunk
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch
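The warmup count computed above determines how many forward-only steps each stage runs before entering the 1F1B steady state: later stages need fewer, and with interleaving every stage additionally works through the earlier model chunks. A small standalone helper restating the same arithmetic as the hunk above (the function name and example values are mine) makes the schedule easier to reason about:

```python
def interleaved_warmup_steps(num_stages: int, stage: int,
                             num_model_chunks: int, num_microbatch_total: int) -> int:
    # Forward passes needed before this stage can run its first backward.
    warmup = (num_stages - stage - 1) * 2
    # Extra steps needed to reach the last model chunk held by this device.
    warmup += (num_model_chunks - 1) * num_stages
    # Never more warmup steps than there are microbatch slots in total.
    return min(warmup, num_microbatch_total)


# Example: 4 stages, 2 chunks per stage, 16 microbatch slots in flight.
assert interleaved_warmup_steps(4, 0, 2, 16) == 10  # first stage warms up longest
assert interleaved_warmup_steps(4, 3, 2, 16) == 4   # last stage enters 1F1B earliest
```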
@@ -435,35 +429,44 @@ class InterleavedSchedule(PipelineSchedule):
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())

bwd_wait_handles = []
# Get the 1st input batch
model_chunk_id = self.get_model_chunk_id(0, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
input_obj, fwd_wait_handles = self.recv_forward(model_chunk_id)

# Run warmup forward passes.
for i in range(num_warmup_microbatch):
last_iteration = i == num_warmup_microbatch - 1
last_batch = i == num_warmup_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)

# Wait for input
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)

if last_iteration and num_microbatch_remaining == 0:
self.send_forward(model_chunk_id, output_obj)
if last_batch and num_microbatch_remaining == 0:
fwd_wait_handles = self.send_forward(model_chunk_id, output_obj)
else:
input_obj = self.send_forward_recv_forward(
input_obj, fwd_wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=model_chunk_id,
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True),
output_tensor=output_obj,
send_prior=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0,
)

if num_microbatch_remaining > 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)

# Run 1F1B in steady state.
for i in range(num_microbatch_remaining):
last_iteration = i == num_microbatch_remaining - 1
fwd_batch_id = i + num_warmup_microbatch
last_batch = i == num_microbatch_remaining - 1
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)

model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
# Wait for input.
_wait_p2p(fwd_wait_handles)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)

@@ -473,64 +476,75 @@ class InterleavedSchedule(PipelineSchedule):
# Pop output_obj and output_obj from the start of the list for the backward pass.
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

# NOTE: perform 2x communication for forward and backward
def send_forward_recv_backward():
if last_iteration and num_microbatch == num_microbatch_remaining:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
self.send_forward(model_chunk_id, output_obj)
# Helper functions
def send_forward_recv_forward():
if last_batch:
model_chunk_id = self.get_model_chunk_id(fwd_batch_id, is_forward=True)
wait_handles = self.send_forward(model_chunk_id, output_obj)
return None, wait_handles
else:
output_obj_grad = self.send_forward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_obj, wait_handles = self.send_forward_recv_forward(
model_chunk_id_send=self.get_model_chunk_id(fwd_batch_id, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(fwd_batch_id + 1, is_forward=True),
output_tensor=output_obj,
send_prior_fallback=self.stage_manager.stage % 2 == 0,
send_first=self.stage_manager.stage % 2 == 0
and i > 0,  # Receive from warmup stage first in the first batch
)
return output_obj_grad
return input_obj, wait_handles

def send_backward_recv_forward():
if last_iteration:
def send_backward_recv_backward():
no_cooldown = num_microbatch == num_microbatch_remaining
if last_batch and no_cooldown:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
wait_handles = self.send_backward(model_chunk_id, input_obj_grad)
return None, wait_handles
else:
input_obj = self.send_backward_recv_forward(
output_obj_grad, wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0,
send_first=self.stage_manager.stage % 2 == 0,
)
return input_obj
return output_obj_grad, wait_handles

if self.stage_manager.stage % 2 == 0:
output_obj_grad = send_forward_recv_backward()
input_obj = send_backward_recv_forward()
else:
input_obj = send_backward_recv_forward()
output_obj_grad = send_forward_recv_backward()
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
# however in practice this works fine, and Megatron does this too
# (https://github.com/microsoft/Megatron-DeepSpeed/blob/bcedecd1ff788d4d363f3365fd396053a08d65be/megatron/core/pipeline_parallel/schedules.py#L774)
# if deadlock, call _wait_p2p(fwd_wait_handles) here
output_obj_grad, bwd_wait_handles = send_backward_recv_backward()

if num_microbatch_remaining == 0:
model_chunk_id = self.get_model_chunk_id(0, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
output_obj_grad, bwd_wait_handles = self.recv_backward(model_chunk_id)

# Run cooldown backward passes.
for i in range(num_microbatch_remaining, num_microbatch):
last_iteration = i == num_microbatch - 1
last_batch = i == num_microbatch - 1
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
_input_obj = input_objs[model_chunk_id].pop(0)
_output_obj = output_objs[model_chunk_id].pop(0)
# output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

if not last_iteration:
output_obj_grad = self.send_backward_recv_backward(
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:
output_obj_grad, bwd_wait_handles = self.send_backward_recv_backward(
model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False),
model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False),
input_tensor_grad=input_obj_grad,
send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
send_first=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining,
)
assert (not self.overlap_p2p) or len(bwd_wait_handles) > 0
else:
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
_ = self.send_backward(model_chunk_id, input_obj_grad)

assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

@@ -45,7 +45,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches is not None or microbatch_size is not None
), "Either num_microbatches or microbatch_size should be provided"

self.comm = PipelineP2PCommunication(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None

@@ -124,7 +125,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input tensor or input tensor list.
"""
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
input_tensor, _ = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)

@@ -141,7 +142,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Any: The input gradient tensor or gradient tensor list.
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

@@ -171,9 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache

def send_forward_recv_backward(
self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.

@@ -183,13 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_prior_fallback = None  # must not fallback
output_tensor_grad = self.comm.send_forward_recv_backward(
send_first = None
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
next_rank,
send_metadata=self.send_tensor_metadata,
metadata_recv=self.grad_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:

@@ -197,9 +195,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

return output_tensor_grad

def send_backward_recv_forward(
self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None
) -> Any:
def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.

@@ -209,13 +205,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_prior_fallback = None  # must not fallback
input_tensor = self.comm.send_backward_recv_forward(
send_first = None  # must not fallback
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
prev_rank,
send_metadata=self.send_grad_metadata,
metadata_recv=self.tensor_metadata_recv,
send_prior_fallback=send_prior_fallback,
send_first=send_first,
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:

@@ -381,9 +376,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
last_iteration = i == (num_microbatches_remaining - 1)

output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
output_obj_grad = self.send_forward_recv_backward(
output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0
)
output_obj_grad = self.send_forward_recv_backward(output_obj, send_first=self.stage_manager.stage % 2 == 0)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)

@@ -398,7 +391,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_backward(input_obj_grad)
else:
input_obj = self.send_backward_recv_forward(
input_obj_grad, send_first=self.stage_manager.stage % 2 == 0
)

# Run cooldown backward passes.

@@ -35,7 +35,7 @@ class PipelineStageManager:
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
self.p2p_groups: Dict[Tuple[int, ...], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage

@@ -48,30 +48,14 @@ class PipelineStageManager:
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")

# init p2p process groups
stages = list(range(self.num_stages))
for prev, cur in zip(stages[:-1], stages[1:]):
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
if self.stage in [prev, cur]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group

self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group

# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
self.p2p_group = self.pg_mesh.get_group_along_axis(self.pipeline_axis)

def get_stage_index(
self,

@@ -184,19 +168,12 @@ class PipelineStageManager:
"""
return self.next_rank

def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
def get_p2p_process_group(self) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.

Args:
first_rank (int): The first rank.
second_rank (int): The second rank.

Returns:
ProcessGroup: P2P process group between the two ranks.
"""
if first_rank > second_rank:
first_rank, second_rank = second_rank, first_rank
return self.p2p_groups[(first_rank, second_rank)]
return self.p2p_group

def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
"""Get the process group of the given stages.

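The stage-manager change above drops the per-pair P2P process groups in favor of a single group spanning the whole pipeline axis; point-to-point calls then simply address the peer's rank. A minimal sketch of that idea with raw `torch.distributed` (helper names, ranks, and an already-initialized default group are all assumptions of this example):

```python
from typing import List

import torch
import torch.distributed as dist


def setup_pipeline_group(pipeline_ranks: List[int]):
    # One group shared by all pipeline stages, analogous to self.p2p_group above.
    return dist.new_group(ranks=pipeline_ranks)


def send_to_next_stage(tensor: torch.Tensor, next_rank: int, group) -> None:
    # Point-to-point send to the peer's rank; no dedicated pair-wise group is needed.
    dist.send(tensor, dst=next_rank, group=group)
```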
@@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. Commonly, Q, K, V are fused in one weight.
"""
LazyInitContext.materialize(module)

# get the attributes
in_features = module.in_features
out_features = module.out_features

@@ -8,8 +8,15 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
TokenClassifierOutput,
)
from transformers.models.t5.modeling_t5 import (
T5EncoderModel,
T5ForConditionalGeneration,
T5ForTokenClassification,
T5Model,
T5Stack,
)
from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager

@@ -582,6 +589,71 @@ class T5PipelineForwards:

return outputs

@staticmethod
def t5_for_token_classification_forward(
self: T5ForTokenClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward.
Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = T5PipelineForwards.t5_stack_forward(
self.transformer.encoder,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
if stage_manager.is_last_stage():
sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

return outputs


def get_t5_flash_attention_forward():
from transformers.models.t5.modeling_t5 import T5Attention

@@ -68,6 +68,9 @@ _POLICY_LIST = {
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
"transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
file_name="t5", class_name="T5ForTokenClassificationPolicy"
),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(

@@ -31,7 +31,13 @@ from ..modeling.t5 import (
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
__all__ = [
"distribute_t5_layers",
"T5ModelPolicy",
"T5ForConditionalGenerationPolicy",
"T5EncoderPolicy",
"T5ForTokenClassificationPolicy",
]


class T5BasePolicy(Policy):

@@ -312,9 +318,13 @@ class T5BasePolicy(Policy):
assert self.pipeline_stage_manager is not None
stage_manager = self.pipeline_stage_manager

model = self.model
encoder = self.model.encoder
decoder = getattr(self.model, "decoder", None)
if self.model.__class__.__name__ == "T5ForTokenClassification":
model = self.model.transformer
else:
model = self.model

encoder = model.encoder
decoder = getattr(model, "decoder", None)

num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0

@@ -353,7 +363,11 @@ class T5BasePolicy(Policy):
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager

encoder = self.model.encoder
if self.model.__class__.__name__ == "T5ForTokenClassification":
encoder = self.model.transformer.encoder
else:
encoder = self.model.encoder

decoder = getattr(self.model, "decoder", None)

num_encoder_layers = len(encoder.block)

@@ -542,3 +556,46 @@ class T5EncoderPolicy(T5BasePolicy):

def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []


class T5ForTokenClassificationPolicy(T5EncoderPolicy):
def module_policy(self):
from transformers.models.t5.modeling_t5 import T5ForTokenClassification

policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
addon_module = {
T5ForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=T5ForTokenClassification,
new_forward=T5PipelineForwards.t5_for_token_classification_forward,
policy=policy,
)

return policy

def get_held_layers(self) -> List[nn.Module]:
"""
get pipeline layers for current stage
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
# no shared params for sequence classification model
return []

@@ -403,9 +403,9 @@ class Chunk:
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
)

input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
self.grad_reduce_work = dist.reduce_scatter(
self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
assert self.cuda_global_chunk.is_contiguous()
self.grad_reduce_work = dist.reduce_scatter_tensor(
self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
)

if self.extra_dp_group is not None:

@@ -520,8 +520,10 @@ class Chunk:
assert self.cuda_shard is not None

alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
assert self.cuda_global_chunk.is_contiguous()
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)

self.cuda_shard = None
self.is_gathered = True

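These two hunks swap the list-based collectives for their flat-tensor variants: when the gathered buffer is a single contiguous tensor, `reduce_scatter_tensor` and `all_gather_into_tensor` (available in recent PyTorch releases) avoid building a Python list of chunk views and the extra bookkeeping that implies. A small self-contained comparison, assuming an initialized process group and a buffer size divisible by the world size:

```python
import torch
import torch.distributed as dist


def reduce_scatter_flat(full: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter a contiguous 1-D buffer into this rank's shard."""
    world_size = dist.get_world_size()
    assert full.is_contiguous() and full.numel() % world_size == 0
    shard = torch.empty(full.numel() // world_size, dtype=full.dtype, device=full.device)
    # Flat-tensor variant: no list(torch.chunk(...)) needed.
    dist.reduce_scatter_tensor(shard, full)
    return shard


def all_gather_flat(shard: torch.Tensor) -> torch.Tensor:
    """Gather all shards back into one contiguous buffer."""
    world_size = dist.get_world_size()
    full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(full, shard)
    return full
```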
@@ -133,12 +133,12 @@ class ChunkManager:
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)

def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
"""Move the shard of the chunk to the target device."""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memory_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
chunk.shard_move(device, force_copy, non_blocking=async_move)
self.__add_memory_usage(chunk.memory_usage)

def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:

@@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
p: nn.Parameter,
async_reduce_stream: Optional[torch.cuda.Stream] = None,
):
async_reduce_scatter = async_reduce_stream is not None
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)

@@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
async_reduce_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(async_reduce_stream):
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
if reduced:
grad_chunk.wait_async_reduce()
if not chunk_manager.reuse_fp16_chunk:

@@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
chunk_manager.move_chunk(
grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
chunk_manager.move_chunk(
chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
)
return empty_grad

def zero_grad(self, set_to_none: bool = False) -> None:

@@ -1,3 +1,7 @@
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


@@ -6,6 +10,7 @@ class TensorBucket:
self._max_size = size
self._current_size = 0
self._bucket = []
self._write_back_pairs = {}

@property
def max_size(self):

@@ -21,7 +26,7 @@ class TensorBucket:
def is_empty(self):
return len(self._bucket) == 0

def add_to_bucket(self, tensor, allow_oversize=False):
def add_to_bucket(self, tensor, allow_oversize=False, write_back_tensor: Optional[torch.Tensor] = None):
tensor_size = tensor.numel()

if not allow_oversize and self.will_exceed_max_size(tensor_size):

@@ -30,6 +35,8 @@ class TensorBucket:

self._bucket.append(tensor)
self._current_size += tensor_size
write_back_tensor = write_back_tensor if write_back_tensor is not None else tensor
self._write_back_pairs[tensor] = write_back_tensor

def will_exceed_max_size(self, tensor_size):
expected_size = self._current_size + tensor_size

@@ -40,12 +47,30 @@ class TensorBucket:

def empty(self):
self._bucket = []
self._size = 0
self._current_size = 0
self._write_back_pairs = {}

def flatten(self):
return _flatten_dense_tensors(self._bucket)

def unflatten(self, flat_tensor):
return _unflatten_dense_tensors(flat_tensor, self._bucket)

def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
unflattened_tensor_list = self.unflatten(flat_tensor)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)

def all_gather(self, group=None):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
write_back_tensor = self._write_back_pairs[tensor]
write_back_tensor.data.copy_(
_flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
)
self.empty()

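The `TensorBucket.all_gather` added above flattens many small parameter shards into one buffer, all-gathers that single buffer, and writes each gathered tensor back to its registered destination, which is far cheaper than issuing one `all_gather` per parameter. The sketch below shows how a caller might drive such a bucket; the function name and arguments are illustrative, and the `RuntimeError` fallback mirrors the oversize handling used in the optimizer hunk further down.

```python
from typing import Iterable, Tuple

import torch


def bucketed_all_gather(bucket, shards_and_params: Iterable[Tuple[torch.Tensor, torch.Tensor]], pg=None):
    # `bucket` is a TensorBucket; each (shard, full_param) pair registers where
    # the gathered result must be written back.
    for shard, full_param in shards_and_params:
        try:
            bucket.add_to_bucket(shard, write_back_tensor=full_param)
        except RuntimeError:
            # Bucket is full: flush it, then retry the tensor that did not fit.
            bucket.all_gather(group=pg)
            bucket.add_to_bucket(shard, write_back_tensor=full_param)
    if not bucket.is_empty():
        bucket.all_gather(group=pg)
```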
@@ -23,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor

from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

@@ -694,34 +694,33 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])

tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)
moe_tensor_bucket = TensorBucket(self._bucket_store.reduce_bucket_size)

# update working partition updated by the current rank
device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
param_to_gather = splited_param.to(device).to(self._dtype)
if self._bucket_store.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._bucket_store.moe_extra_dp_pg_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self._bucket_store.moe_extra_dp_pg,
)
try:
moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
moe_tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._bucket_store.zero_world_size)
]
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self._bucket_store.torch_pg,
)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
try:
tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
tensor_bucket.add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
if not moe_tensor_bucket.is_empty():
moe_tensor_bucket.all_gather(self._bucket_store.moe_extra_dp_pg)
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(self._bucket_store.torch_pg)

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""

@@ -9,6 +9,7 @@
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://cloud.luchentech.com/">Luchen Cloud </a> |
<a href="https://hpc-ai.com/blog"> Blog </a></h3>

[](https://github.com/hpcaitech/ColossalAI/stargazers)

@@ -127,6 +128,8 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
[[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
[[Model weights]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#model-weights)
[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[OpenSora Image]](https://cloud.luchentech.com/doc/docs/image/open-sora/)

<div align="center">
<a href="https://www.bilibili.com/video/BV1Fm421G7bV">

@@ -135,6 +138,8 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
</div>

### Colossal-LLaMA-2
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)

- 7B: Trained in half a day on a budget of about one thousand RMB, with results comparable to mainstream large models; an open-source, commercially usable Chinese LLaMA-2.
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)

@@ -265,7 +270,9 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
</p>

- 70 billion parameter LLaMA3 model training accelerated by 18%
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)

### LLaMA2
<p align="center">

@@ -378,6 +385,8 @@ Colossal-AI provides you with a collection of parallel components. Our goal is to let your
- Large AI model inference speed nearly doubled in some cases, compared with the offline inference performance of vLLM
[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference)
[[blog]](https://hpc-ai.com/blog/colossal-inference)
[[Luchen Cloud]](https://cloud.luchentech.com/)
[[LLaMA3 Image]](https://cloud.luchentech.com/doc/docs/image/llama)

### Grok-1
<p id="Grok-1" align="center">

@@ -1,4 +1,12 @@
# Colossal-AI Examples
<div align="center">

<h3>
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> |
<a href="https://cloud.luchentech.com/doc/docs/intro"> Playground Document </a>
</h3>

</div>

## Table of Contents

@@ -1,9 +1,11 @@
import argparse
import resource
import time
import warnings
from contextlib import nullcontext

import torch
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context

@@ -21,11 +23,19 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
# ==============================
# Constants
# ==============================

MODEL_CONFIGS = {
"100m": LlamaConfig(
max_position_embeddings=4096,
num_hidden_layers=4,
num_attention_heads=32,
intermediate_size=2048,
hidden_size=1024,
),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,

@@ -58,6 +68,9 @@ def main():
default="gemini",
help="Choose which plugin to use",
)
parser.add_argument(
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
)
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")

@@ -78,11 +91,13 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
parser.add_argument(
"--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
)

parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
args = parser.parse_args()

colossalai.launch_from_torch()

@@ -98,6 +113,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
}
if args.custom_ckpt
else {}

@@ -174,6 +190,8 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
enable_sequence_parallelism=args.sp > 1,

@@ -182,12 +200,16 @@ def main():
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
pp_style=args.pp_style,
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),

@@ -195,6 +217,7 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")

@@ -210,10 +233,11 @@ def main():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)

# ==============================
# Initialize Model and Optimizer

@@ -229,8 +253,13 @@ def main():
init_kwargs["empty_init"] = False

with init_ctx:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)

model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=True,
**init_kwargs,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if config.model_type == "chatglm":

@@ -251,6 +280,7 @@ def main():
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"

@@ -261,7 +291,7 @@ def main():

with get_profile_context(
args.profile,
1,
args.ignore_steps,
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:

@@ -269,15 +299,19 @@ def main():
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=False,
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()

performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:

@@ -288,6 +322,7 @@ def main():
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()

performance_evaluator.on_step_end(**batch)
prof.step()

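The benchmark changes above expose the new pipeline options on the command line (`--pp_style`, `--n_chunks`, `--overlap`, `--no_cache`) and forward them into `HybridParallelPlugin`. The snippet below is a hedged sketch of constructing the plugin with these options directly; it uses only parameters that appear in the diff, and the concrete sizes are illustrative, not a recommended configuration.

```python
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative values; tp_size * pp_size must match the devices you launch with.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="interleaved",   # use the interleaved 1F1B schedule
    num_model_chunks=2,       # model chunks held by each pipeline stage
    zero_stage=1,
    microbatch_size=1,
    precision="bf16",
    overlap_p2p=True,         # enable the new non-blocking P2P path
    enable_metadata_cache=True,
)
```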
@@ -40,6 +40,14 @@ def data_gen_for_t5_model():
     return data


+def data_gen_for_token_classification():
+    # token classification data gen
+    # `labels` is the type not the token id for token classification, 0 or 1
+    data = data_gen_for_encoder_only()
+    data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    return data
+
+
 # output transform function
 output_transform_fn = lambda x: x

@@ -47,6 +55,7 @@ output_transform_fn = lambda x: x
 loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
 loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
 loss_fn_for_conditional_generation = lambda x: x["loss"]
+loss_fn_for_token_classification = lambda x: x["loss"]

 # define model config
 config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

@@ -79,3 +88,11 @@ model_zoo.register(
     loss_fn=loss_fn_for_encoder_only,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
+model_zoo.register(
+    name="transformers_t5_for_token_classification",
+    model_fn=lambda: transformers.T5ForTokenClassification(config),
+    data_gen_fn=data_gen_for_token_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn_for_token_classification,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

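For reference, the newly registered `transformers_t5_for_token_classification` entry pairs `T5ForTokenClassification` (available in recent `transformers` releases) with per-token 0/1 labels. A standalone sketch of what that combination computes; the input tensors below are made up for illustration, since the model zoo's `data_gen_for_encoder_only` helper is not shown in this diff:

```python
import torch
import transformers

# Small config matching the model zoo entry above.
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
model = transformers.T5ForTokenClassification(config)

# Illustrative encoder-only batch: one sequence of 16 tokens with 0/1 token-level labels.
data = {
    "input_ids": torch.randint(0, config.vocab_size, (1, 16)),
    "attention_mask": torch.ones(1, 16, dtype=torch.int64),
    "labels": torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64),
}

outputs = model(**data)
# loss_fn_for_token_classification above simply reads the "loss" entry of the output.
print(outputs.loss, outputs.logits.shape)  # logits: (1, 16, config.num_labels)
```
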
@@ -15,8 +15,7 @@ WORLD_SIZE = 2
 def check_p2p_communication():
     pg_mesh = ProcessGroupMesh(WORLD_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, 0)
-    p2p = PipelineP2PCommunication(stage_manager)
-
+    p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
     rank = dist.get_rank()

     tensor = torch.ones(1, device=get_accelerator().get_current_device())
@@ -31,41 +30,40 @@ def check_p2p_communication():
         for obj in data:
             p2p.send_forward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
+            recv_obj, _ = p2p.send_forward_recv_backward(data[i], send_first=False)
             assert recv_obj == data[-(i + 1)]
     elif rank == 1:
         for obj in data:
-            recv_obj = p2p.recv_forward()
+            recv_obj, _ = p2p.recv_forward()
             assert recv_obj == obj
         for i in range(len(data)):
             p2p.send_backward(data[-(i + 1)])
-            recv_obj = p2p.recv_forward()
+            recv_obj, _ = p2p.recv_forward()
             assert recv_obj == data[i]

     if rank == 1:
         for obj in data:
             p2p.send_backward(obj)
         for i in range(len(data)):
-            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
+            recv_obj, _ = p2p.send_backward_recv_forward(data[i], send_first=True)
             assert recv_obj == data[-(i + 1)]
     elif rank == 0:
         for obj in data:
-            recv_obj = p2p.recv_backward()
+            recv_obj, _ = p2p.recv_backward()
             assert recv_obj == obj
         for i in range(len(data)):
-            recv_obj = p2p.recv_backward()
-            p2p.send_forward(data[-(i + 1)])
+            recv_obj, _ = p2p.send_forward_recv_backward(data[-(i + 1)], send_first=False)
             assert recv_obj == data[i]

     if rank == 0:
-        recv_obj = p2p.send_forward_recv_backward(
+        recv_obj, _ = p2p.send_forward_recv_backward(
             tensor,
             send_metadata=False,
             metadata_recv=create_send_metadata(tensor),
         )
         assert recv_obj == tensor
     elif rank == 1:
-        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
+        recv_obj, _ = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
         assert recv_obj == tensor
         p2p.send_backward(tensor, send_metadata=False)

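The test changes above track an API change in `PipelineP2PCommunication`: the constructor now takes an `overlap_p2p` flag, `send_prior_fallback` is replaced by `send_first`, and the receive/exchange calls return a tuple whose first element is the received object (the test discards the second element with `_`; given the `overlap_p2p` flag it presumably carries async wait handles, though that is an inference from this diff, not something it states). A minimal sketch of the updated call pattern, assuming a two-rank launch (e.g. `torchrun --nproc_per_node=2`) with ColossalAI already initialized; the import paths follow the ColossalAI layout at this version and the helper name `exchange_once` is illustrative:

```python
import torch
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager


def exchange_once(tensor: torch.Tensor):
    """Mirror the updated test: rank 0 sends forward and receives backward, rank 1 does the reverse."""
    pg_mesh = ProcessGroupMesh(2)
    stage_manager = PipelineStageManager(pg_mesh, 0)
    p2p = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

    if dist.get_rank() == 0:
        # send_first=False issues the receive before the send, as in the updated test.
        recv_obj, _ = p2p.send_forward_recv_backward(tensor, send_first=False)
    else:
        recv_obj, _ = p2p.recv_forward()
        p2p.send_backward(tensor)
    return recv_obj
```
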
@@ -52,7 +52,7 @@ def check_stage_manager():
     # check p2p groups
     for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
         if rank in [prev, cur]:
-            group = stage_manager.get_p2p_process_group(prev, cur)
+            group = stage_manager.get_p2p_process_group()
             dist.barrier(group=group)

     # check stage groups

@@ -41,14 +41,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     t5 = unwrap_model(org_model)
     sharded_t5 = unwrap_model(sharded_model)

-    row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]
+    if t5.__class__.__name__ == "T5ForTokenClassification":
+        row_layer_for_check = ["transformer.shared", "transformer.encoder.block[0].layer[0].SelfAttention.q"]
+    else:
+        row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"]

     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-5, 1e-3
     else:
-        atol, rtol = 5e-3, 5e-3
+        atol, rtol = 5e-2, 5e-2
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         row_layer_grads = get_grad_tensors_for_check(
             t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0
@@ -66,7 +69,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     else:
         atol, rtol = 5e-3, 5e-3

-    if org_model.__class__.__name__ != "T5ForConditionalGeneration":
+    if org_model.__class__.__name__ not in ["T5ForConditionalGeneration", "T5ForTokenClassification"]:
         check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -157,7 +160,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 @clear_cache_before_run()
 def run_t5_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
+    sub_model_zoo = model_zoo.get_sub_registry(["transformers_t5_for_token_classification"])

     for name, (
         model_fn,
@@ -167,7 +170,10 @@ def run_t5_test(test_config):
         _,
     ) in sub_model_zoo.items():
         # skip 4-stage pp test for t5_encoder
-        if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model":
+        if test_config["pp_size"] > 2 and name in [
+            "transformers_t5_encoder_model",
+            "transformers_t5_for_token_classification",
+        ]:
             continue

         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

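The `row_layer_for_check` branch above exists because `T5ForTokenClassification` wraps the encoder under a `transformer` attribute, so its parameter names carry a `transformer.` prefix relative to the other T5 variants. A quick, standalone way to confirm that prefix outside the test (reusing the small config from the model zoo entry):

```python
import transformers

config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
model = transformers.T5ForTokenClassification(config)

# Parameter names are prefixed with "transformer.", e.g. "transformer.shared.weight".
print([name for name, _ in model.named_parameters() if name.startswith("transformer.shared")])
```
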
@@ -1 +1 @@
-0.3.9
+0.4.0