[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when `strict=False`, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)

* Use p2p

* Cannot do bidirectional p2p send

* Refactor tensor creation and serialization in P2P
communication

* Fix llama forward args in flash attention

* Add flop estimate from megatron

* Support skipping weights not in weight_map when strict=False in hybrid_parallel

* Use send_forward_recv_backward, etc in 1f1b

* Use dataclass for metadata
Remove torch.cuda.synchronize() as suggested

* Add comment about the torch.cuda.synchronize for potential error

* Typo

* Update hybrid_parallel_checkpoint_io.py

* Update p2p.py

* Update one_f_one_b.py

* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
pull/5060/head
Elsa Granger 1 year ago committed by GitHub
parent 28052a71fb
commit b2ad0d9e8f

@@ -1,4 +1,5 @@
import copy
from functools import reduce
import logging
import os
from pathlib import Path
@@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
missing_keys = []
missing_file_keys = []
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
missing_file_keys.append(name)
return
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
if len(missing_keys) == 0:
raise RuntimeError(
"No weigth is loaded into the model. Please check the checkpoint files and the model structure."
)
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
else:
if self.coordinator.is_master():
logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
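A minimal standalone sketch (not part of the diff) of the new bookkeeping above: keys absent from weight_map are recorded in missing_file_keys instead of raising, a key only counts as missing if every loaded shard reported it, and an error is raised only when strict=True. The function and inputs below are hypothetical.

from functools import reduce

def collect_unloaded_keys(param_names, weight_map, per_shard_missing, strict=False):
    # Keys with no entry in weight_map are recorded instead of raising (new behaviour).
    missing_file_keys = [name for name in param_names if name not in weight_map]
    # A key counts as missing only if every loaded shard reported it as missing.
    remain_keys = reduce(lambda a, b: a & b, map(set, per_shard_missing)) if per_shard_missing else set()
    remain_keys = remain_keys.union(missing_file_keys)
    if remain_keys and strict:
        raise RuntimeError(f"Missing key(s) in state_dict: {sorted(remain_keys)}")
    return remain_keys

# Hypothetical example: a tied lm_head.weight has no shard entry; with strict=False it is
# only reported and loading continues.
print(collect_unloaded_keys(
    ["embed.weight", "lm_head.weight"],
    {"embed.weight": "model-00001.safetensors"},
    [set()],
    strict=False,
))  # -> {'lm_head.weight'}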

@@ -5,9 +5,12 @@ import io
import pickle
import re
from typing import Any, List, Optional, Union
from collections import namedtuple
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
@@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
def check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return (
c10d.is_nccl_available() and
pg.name() == c10d.Backend.NCCL
)
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
@@ -65,7 +83,7 @@ def _broadcast_object_list(
c10d._warn_not_in_group("broadcast_object_list")
return
is_nccl_backend = c10d._check_for_nccl_backend(group)
is_nccl_backend = check_for_nccl_backend(group)
current_device = None
if device is not None:
@@ -113,7 +131,7 @@ def _broadcast_object_list(
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
obj_view = object_tensor[offset: offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
@@ -131,6 +149,258 @@ def _broadcast_object_list(
object_list[i] = unpickle_object
def check_device(group):
is_nccl_backend = check_for_nccl_backend(group)
current_device = None
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
return current_device, is_nccl_backend
TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
class P2PDataType(Enum):
serialization = 0
tensor = 1
list = 2
dict = 3
@dataclass
class P2PMetadata:
data_type: P2PDataType
content: Union[List[TensorMetadata], TensorMetadata, Any]
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
if isinstance(obj, torch.Tensor):
obj = obj.contiguous()
op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
ops_queue.append(op_to_add)
else:
for tensor_to_comm in obj:
tensor_to_comm = tensor_to_comm.contiguous()
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
ops_queue.append(op_to_add)
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
if p2p_metadata.data_type == P2PDataType.tensor:
metadata = p2p_metadata.content
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
return tensor_recv
elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
buffer_recv = []
for metadata in p2p_metadata.content:
tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
buffer_recv.append(tensor_recv)
return buffer_recv
else:
raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
buffer_recv = None
if recv_tensor_metadata is not None:
buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
ops = []
if send_dst is not None:
filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
if recv_src is not None:
assert buffer_recv is not None
filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
# torch.cuda.synchronize()
return buffer_recv
def _send_recv_serialization_object(
object: Any,
send_dst: Optional[int], recv_src: Optional[int],
send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
current_device,
is_nccl_backend):
ops = []
send_object_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
else:
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
if is_nccl_backend:
send_object_size_tensor = send_object_size_tensor.to(current_device)
send_object_tensor = send_object_tensor.to(current_device)
filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
recv_object_size_tensor = None
if recv_src is not None:
recv_object_size_tensor = torch.empty(1, dtype=torch.long)
if is_nccl_backend:
recv_object_size_tensor = recv_object_size_tensor.to(current_device)
filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
ops = []
if send_dst is not None and send_object_tensor is not None:
filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None:
recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
if is_nccl_backend:
recv_object_tensor = recv_object_tensor.to(current_device)
filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
if recv_object_tensor is not None and recv_object_size_tensor is not None:
recv_object_tensor = recv_object_tensor.type(torch.uint8)
if recv_object_tensor.device != torch.device("cpu"):
recv_object_tensor = recv_object_tensor.cpu()
unpickle_object = _cuda_safe_tensor_to_object(
recv_object_tensor, recv_object_size_tensor.item())
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
):
unpickle_object = unpickle_object.cuda()
return unpickle_object
def _check_if_fast_send_available(object):
if type(object) is torch.Tensor:
return True
elif type(object) is list:
is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
return is_list_of_tensor
elif type(object) is dict:
is_dict_of_tensor = all([type(k) is str and type(
v) is torch.Tensor for k, v in object.items()])
return is_dict_of_tensor
return False
def _communicate(
object,
send_dst: Optional[int],
recv_src: Optional[int],
send_group: Optional[ProcessGroup] = None,
recv_group: Optional[ProcessGroup] = None,
) -> Any:
if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
c10d._warn_not_in_group("_communicate")
return
current_send_device, is_send_nccl_backend = check_device(send_group)
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
assert current_send_device == current_recv_device
current_device = current_send_device
assert (send_dst is not None) or (recv_src is not None)
can_fast_send = False
send_metadata = None
if send_dst is not None:
can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
if not can_fast_send:
send_metadata = P2PMetadata(P2PDataType.serialization, object)
else:
if type(object) is torch.Tensor:
data_type = P2PDataType.tensor
content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
elif type(object) is list:
data_type = P2PDataType.list
content = []
for v in object:
content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
elif type(object) is dict:
data_type = P2PDataType.dict
content = []
for k, v in object.items():
content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
else:
raise ValueError('Cannot send object of type {}'.format(type(object)))
send_metadata = P2PMetadata(data_type, content)
recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
if recv_metadata is not None:
assert type(recv_metadata) is P2PMetadata
if recv_metadata.data_type == P2PDataType.serialization:
return recv_metadata.content
if not can_fast_send and send_dst is not None:
return
send_tensor_list = None
if type(object) is torch.Tensor:
send_tensor_list = object
elif type(object) is list:
send_tensor_list = object
elif type(object) is dict:
send_tensor_list = list(object.values())
recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
if recv_metadata is not None:
assert recv_buffer is not None
if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
return recv_buffer
elif recv_metadata.data_type == P2PDataType.dict:
return {
k: v
for k, v in zip(
[m.key for m in recv_metadata.content],
recv_buffer,
)
}
else:
raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
"""send anything to dst rank
@@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
Returns:
None
"""
# then broadcast safely
_broadcast_object_list([object], src, group)
_communicate(object, send_dst=dst, recv_src=None, send_group=group)
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
@@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
Returns:
Any: Object received from src.
"""
object_list = [None]
_broadcast_object_list(object_list, src, group)
return object_list[0]
return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
def _p2p_comm(
@@ -302,6 +568,64 @@ class PipelineP2PCommunication:
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
Args:
input_object (Any): Object to be sent.
next_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
return _communicate(
input_object, next_rank, next_rank,
send_group=group, recv_group=group,
)
def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender and recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
return _communicate(
input_object, prev_rank, prev_rank,
send_group=group, recv_group=group,
)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the sender of the tensor
next_rank (int, optional): The rank of the recipient of the tensor
"""
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
return _communicate(
input_object,
send_dst=next_rank,
recv_src=prev_rank,
send_group=send_group,
recv_group=recv_group,
)
def p2p_communicate(
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
) -> None:
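The functions above implement a two-round exchange: first a pickled P2PMetadata handshake (shapes, dtypes, dict keys), then one batched isend/irecv of the raw tensors into buffers allocated from that metadata. The single-process sketch below is not part of the diff (the payload values are illustrative); it only shows how a dict payload maps to TensorMetadata on the sender and is rebuilt on the receiver.

from collections import namedtuple
import torch

TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"])

# Sender side (mirrors _communicate): only this small metadata list is pickled
# and exchanged in the first round.
payload = {"hidden_states": torch.randn(2, 16), "mask": torch.ones(2, 16, dtype=torch.bool)}
metadata = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in payload.items()]

# Receiver side (mirrors create_recv_buffer): allocate empty buffers from the metadata;
# the second round batch-irecvs the raw tensors directly into them.
buffers = [
    torch.empty(m.shape, dtype=m.dtype, requires_grad=m.requires_grad, device="cpu")
    for m in metadata
]
received = {m.key: buf for m, buf in zip(metadata, buffers)}  # dict keys travel in the metadata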

@@ -127,6 +127,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
return self.comm.send_forward_recv_backward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@@ -138,6 +149,33 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
"""Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
return self.comm.send_backward_recv_forward(output_object, prev_rank)
def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
"""Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the previous stage, from which the input tensor is received.
next_rank (int, optional): The rank of the next stage, to which the input tensor is sent.
"""
if self.stage_manager.is_first_stage():
return self.comm.send_forward(input_object, next_rank)
elif self.stage_manager.is_last_stage():
return self.comm.recv_forward(prev_rank)
else:
return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
def forward_step(
self,
model: Module,
@@ -291,7 +329,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if not last_iteration:
input_obj = self.recv_forward()
else:
# TODO adjust here
self.send_forward(output_obj)
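A hedged, single-process sketch of how the fused calls are intended to slot into the 1F1B steady state; the stub communicator below stands in for PipelineP2PCommunication, and forward_step/backward_step are reduced to placeholders. This is not the schedule's actual run loop.

class StubComm:
    """Stand-in for PipelineP2PCommunication; each fused call is one batched round trip."""
    def send_forward_recv_backward(self, output_obj, next_rank=None):
        return {"grad": 0.0}    # gradient that would arrive from the next stage
    def send_backward_recv_forward(self, input_grad, prev_rank=None):
        return {"hidden": 0.0}  # activation that would arrive from the previous stage

comm = StubComm()
input_obj = {"hidden": 0.0}
for _ in range(4):  # steady-state microbatches
    output_obj = {"logits": input_obj["hidden"]}               # forward_step placeholder
    output_grad = comm.send_forward_recv_backward(output_obj)  # one round trip instead of send + recv
    input_grad = {"grad": output_grad["grad"]}                 # backward_step placeholder
    input_obj = comm.send_backward_recv_forward(input_grad)    # one round trip instead of send + recv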

@@ -413,6 +413,7 @@ def get_llama_flash_attention_forward():
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

@@ -183,7 +183,11 @@ def main():
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
performance_evaluator = PerformanceEvaluator(
model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
model_numel,
model.config.num_hidden_layers,
model.config.hidden_size,
model.config.vocab_size,
args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
)
optimizer = HybridAdam(model.parameters())

@@ -58,6 +58,9 @@ class PerformanceEvaluator:
def __init__(
self,
model_numel: int,
num_layers: int,
hidden_size: int,
vocab_size: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
@@ -65,12 +68,16 @@ class PerformanceEvaluator:
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
self.num_layers = num_layers
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.coordinator = DistCoordinator()
self.dp_world_size = dp_world_size or self.coordinator.world_size
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
self.flop_megatron = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
@@ -89,17 +96,20 @@ class PerformanceEvaluator:
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
checkpoint_activations_factor = (3 + int(self.enable_grad_checkpoint))
self.flop_megatron += (24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)) * (1. + (seq_len / (6. * self.hidden_size)) + (self.vocab_size / (16. * self.num_layers * self.hidden_size)))
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.coordinator.world_size // self.dp_world_size
avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
self.coordinator.print_on_master(
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}"
)
self.coordinator.print_on_master(
f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
)
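The Megatron-style estimate added above can be checked standalone. Below is a small sketch of the same formula; the 7B-like configuration in the example is illustrative, not taken from the benchmark script.

def megatron_flops_per_iteration(batch_size: int, seq_len: int, num_layers: int,
                                 hidden_size: int, vocab_size: int,
                                 grad_checkpoint: bool = False) -> float:
    """Megatron-LM style FLOP count for one forward+backward step.

    The factor is 3 (1 forward + 2 backward) without activation checkpointing
    and 4 with it, matching checkpoint_activations_factor above.
    """
    ckpt_factor = 3 + int(grad_checkpoint)
    return (24 * ckpt_factor * batch_size * seq_len * num_layers * hidden_size**2) * (
        1.0 + seq_len / (6.0 * hidden_size) + vocab_size / (16.0 * num_layers * hidden_size)
    )

# Illustrative 7B-like config: 32 layers, hidden size 4096, vocab size 32000.
flops = megatron_flops_per_iteration(batch_size=8, seq_len=4096,
                                     num_layers=32, hidden_size=4096, vocab_size=32000)
print(f"{flops / 1e12:.1f} TFLOPs per iteration")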
