mirror of https://github.com/hpcaitech/ColossalAI
[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when `strict=False`, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)
* Use p2p
* Cannot bidirectional send p2p
* Refactor tensor creation and serialization in P2P communication
* Fix llama forward args in flash attention
* Add flop estimate from megatron
* Support loading weight not in weight_map when strict=False in hybrid_parallel
* Use send_forward_recv_backward, etc in 1f1b
* Use dataclass for metadata; remove torch.cuda.synchronize() as suggested
* Add comment about the torch.cuda.synchronize for potential error
* Typo
* Update hybrid_parallel_checkpoint_io.py
* Update p2p.py
* Update one_f_one_b.py
* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
parent 28052a71fb
commit b2ad0d9e8f
@@ -1,4 +1,5 @@
import copy
from functools import reduce
import logging
import os
from pathlib import Path
@@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        # Keep a record of loaded files so that file will not be repeatedly loaded.
        loaded_file = set()

        missing_keys = []
        missing_file_keys = []

        def _load(name: str):
            if name not in weight_map:
-               raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+               missing_file_keys.append(name)
+               return
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
-           missing_keys = []

            load_state_dict_into_model(
                model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        if self.verbose and self.coordinator.is_master():
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

        if len(missing_keys) == 0:
            raise RuntimeError(
                "No weight is loaded into the model. Please check the checkpoint files and the model structure."
            )

        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
        remain_keys = remain_keys.union(set(missing_file_keys))
        if len(remain_keys) > 0:
            if strict:
                error_msgs = "Missing key(s) in state_dict: {}. ".format(
                    ", ".join('"{}"'.format(k) for k in missing_keys)
                )
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        self.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
            else:
                if self.coordinator.is_master():
                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")

    def save_sharded_optimizer(
        self,
        optimizer: OptimizerWrapper,
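A minimal, self-contained sketch of the skip-missing-keys behavior introduced above, assuming a plain dict-based weight_map; the helper name and arguments are illustrative, not the ColossalAI API:

from functools import reduce


def collect_unloaded_keys(param_names, weight_map, shard_missing_keys, strict=False):
    # Keys with no entry in the checkpoint's weight_map are recorded instead of raising.
    missing_file_keys = [name for name in param_names if name not in weight_map]
    # Keys missing from every loaded shard (intersection across shards).
    remain_keys = reduce(lambda a, b: a & b, map(set, shard_missing_keys))
    remain_keys = remain_keys.union(set(missing_file_keys))
    if remain_keys and strict:
        raise RuntimeError(f"Missing key(s) in state_dict: {sorted(remain_keys)}")
    return remain_keys  # with strict=False these are only reported


# "head.bias" has no checkpoint shard; with strict=False it is skipped, not fatal.
leftover = collect_unloaded_keys(
    ["encoder.weight", "head.bias"],
    {"encoder.weight": "model-00001.bin"},
    shard_missing_keys=[["head.bias"]],
    strict=False,
)
print(leftover)  # {'head.bias'}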
@@ -5,9 +5,12 @@ import io
import pickle
import re
from typing import Any, List, Optional, Union
from collections import namedtuple

import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
@@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
    return unpickle


def check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
        # It is not expected for PG to be wrapped many times, but support it just
        # in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return (
        c10d.is_nccl_available() and
        pg.name() == c10d.Backend.NCCL
    )


def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
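check_for_nccl_backend mirrors torch's private c10d._check_for_nccl_backend while unwrapping _ProcessGroupWrapper layers. A rough public-API sketch of the decision it enables (where to stage pickled metadata), assuming the process group is already initialized; this is illustrative, not part of the patch:

import torch
import torch.distributed as dist


def pick_staging_device(group=None):
    # NCCL groups need the metadata tensor on the current CUDA device;
    # other backends (e.g. gloo) can keep it on CPU.
    if dist.get_backend(group) == dist.Backend.NCCL:
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")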
@@ -65,7 +83,7 @@ def _broadcast_object_list(
        c10d._warn_not_in_group("broadcast_object_list")
        return

-   is_nccl_backend = c10d._check_for_nccl_backend(group)
+   is_nccl_backend = check_for_nccl_backend(group)
    current_device = None

    if device is not None:
@@ -131,6 +149,258 @@ def _broadcast_object_list(
            object_list[i] = unpickle_object


def check_device(group):
    is_nccl_backend = check_for_nccl_backend(group)
    current_device = None

    current_device = torch.device("cpu")
    if is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    return current_device, is_nccl_backend


TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])


class P2PDataType(Enum):
    serialization = 0
    tensor = 1
    list = 2
    dict = 3


@dataclass
class P2PMetadata:
    data_type: P2PDataType
    content: Union[List[TensorMetadata], TensorMetadata, Any]


def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
    if isinstance(obj, torch.Tensor):
        obj = obj.contiguous()
        op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
        ops_queue.append(op_to_add)
    else:
        for tensor_to_comm in obj:
            tensor_to_comm = tensor_to_comm.contiguous()
            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
            ops_queue.append(op_to_add)


def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
    if p2p_metadata.data_type == P2PDataType.tensor:
        metadata = p2p_metadata.content
        tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
        return tensor_recv
    elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
        buffer_recv = []
        for metadata in p2p_metadata.content:
            tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
            buffer_recv.append(tensor_recv)
        return buffer_recv
    else:
        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")


def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
    buffer_recv = None
    if recv_tensor_metadata is not None:
        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)

    ops = []

    if send_dst is not None:
        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)

    if recv_src is not None:
        assert buffer_recv is not None
        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        torch.cuda.synchronize()

    # Remove synchronization according to PyTorch's documentation
    # However, Megatron-LM does synchronization here
    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
    # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
    # torch.cuda.synchronize()

    return buffer_recv


def _send_recv_serialization_object(
        object: Any,
        send_dst: Optional[int], recv_src: Optional[int],
        send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
        current_device,
        is_nccl_backend):
    ops = []
    send_object_tensor = None
    if object is not None and send_dst is not None:
        if Version(torch.__version__) >= Version("1.13.0"):
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
        else:
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)

        if is_nccl_backend:
            send_object_size_tensor = send_object_size_tensor.to(current_device)
            send_object_tensor = send_object_tensor.to(current_device)

        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

    recv_object_size_tensor = None
    if recv_src is not None:
        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            recv_object_size_tensor = recv_object_size_tensor.to(current_device)
        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        torch.cuda.synchronize()

    # See the comment in `_batch_send_recv_tensor`
    # torch.cuda.synchronize()

    ops = []

    if send_dst is not None and send_object_tensor is not None:
        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

    recv_object_tensor = None
    if recv_src is not None and recv_object_size_tensor is not None:
        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
        if is_nccl_backend:
            recv_object_tensor = recv_object_tensor.to(current_device)
        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        torch.cuda.synchronize()

    # See the comment in `_batch_send_recv_tensor`
    # torch.cuda.synchronize()

    if recv_object_tensor is not None and recv_object_size_tensor is not None:
        recv_object_tensor = recv_object_tensor.type(torch.uint8)
        if recv_object_tensor.device != torch.device("cpu"):
            recv_object_tensor = recv_object_tensor.cpu()

        unpickle_object = _cuda_safe_tensor_to_object(
            recv_object_tensor, recv_object_size_tensor.item())

        if (
            isinstance(unpickle_object, torch.Tensor)
            and unpickle_object.device.index != torch.cuda.current_device()
        ):
            unpickle_object = unpickle_object.cuda()

        return unpickle_object


def _check_if_fast_send_available(object):
    if type(object) is torch.Tensor:
        return True
    elif type(object) is list:
        is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
        return is_list_of_tensor
    elif type(object) is dict:
        is_dict_of_tensor = all([type(k) is str and type(
            v) is torch.Tensor for k, v in object.items()])

        return is_dict_of_tensor
    return False


def _communicate(
    object,
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
) -> Any:
    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
        c10d._warn_not_in_group("_communicate")
        return

    current_send_device, is_send_nccl_backend = check_device(send_group)
    current_recv_device, is_recv_nccl_backend = check_device(recv_group)

    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

    assert current_send_device == current_recv_device
    current_device = current_send_device

    assert (send_dst is not None) or (recv_src is not None)

    can_fast_send = False
    send_metadata = None
    if send_dst is not None:
        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
        if not can_fast_send:
            send_metadata = P2PMetadata(P2PDataType.serialization, object)
        else:
            if type(object) is torch.Tensor:
                data_type = P2PDataType.tensor
                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
            elif type(object) is list:
                data_type = P2PDataType.list
                content = []
                for v in object:
                    content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
            elif type(object) is dict:
                data_type = P2PDataType.dict
                content = []
                for k, v in object.items():
                    content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
            else:
                raise ValueError('Cannot send object of type {}'.format(type(object)))
            send_metadata = P2PMetadata(data_type, content)

    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
    if recv_metadata is not None:
        assert type(recv_metadata) is P2PMetadata
        if recv_metadata.data_type == P2PDataType.serialization:
            return recv_metadata.content
    if not can_fast_send and send_dst is not None:
        return

    send_tensor_list = None
    if type(object) is torch.Tensor:
        send_tensor_list = object
    elif type(object) is list:
        send_tensor_list = object
    elif type(object) is dict:
        send_tensor_list = list(object.values())

    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)

    if recv_metadata is not None:
        assert recv_buffer is not None
        if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
            return recv_buffer
        elif recv_metadata.data_type == P2PDataType.dict:
            return {
                k: v
                for k, v in zip(
                    [m.key for m in recv_metadata.content],
                    recv_buffer,
                )
            }
        else:
            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))


def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
    """send anything to dst rank
@@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
    Returns:
        None
    """
-   # then broadcast safely
-   _broadcast_object_list([object], src, group)
+   _communicate(object, send_dst=dst, recv_src=None, send_group=group)


def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
@@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
    Returns:
        Any: Object received from src.
    """
-   object_list = [None]
-   _broadcast_object_list(object_list, src, group)
-
-   return object_list[0]
+   return _communicate(None, send_dst=None, recv_src=src, recv_group=group)


def _p2p_comm(
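The helpers above replace broadcast-based object exchange with batched point-to-point calls. An illustrative standalone program showing the underlying dist.P2POp / dist.batch_isend_irecv pattern; it assumes two ranks and an NCCL-capable setup, and is a sketch rather than part of the patch:

import torch
import torch.distributed as dist


def main():
    # Launch with: torchrun --nproc_per_node=2 p2p_demo.py
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    ops = []
    if rank == 0:
        payload = torch.arange(4, dtype=torch.float32, device=device)
        ops.append(dist.P2POp(dist.isend, payload.contiguous(), 1))
    else:
        buffer = torch.empty(4, dtype=torch.float32, device=device)
        ops.append(dist.P2POp(dist.irecv, buffer, 0))

    # One grouped NCCL call instead of separate blocking send/recv.
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    if rank == 1:
        print("rank 1 received:", buffer)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()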
@@ -302,6 +568,64 @@ class PipelineP2PCommunication:
        cur_rank = self.stage_manager.get_rank()
        _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
        """Sends the input tensor to the next stage and copies the gradient tensor from the next stage in pipeline.

        Args:
            input_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the sender and recipient of the tensor
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()

        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
        return _communicate(
            input_object, next_rank, next_rank,
            send_group=group, recv_group=group,
        )

    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
        """Sends the gradient tensor to the previous stage and copies the input tensor from the previous stage in pipeline.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the sender and recipient of the tensor
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()

        cur_rank = self.stage_manager.get_rank()
        group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
        return _communicate(
            input_object, prev_rank, prev_rank,
            send_group=group, recv_group=group,
        )

    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the sender of the tensor
            next_rank (int, optional): The rank of the recipient of the tensor
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()

        cur_rank = self.stage_manager.get_rank()
        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
        return _communicate(
            input_object,
            send_dst=next_rank,
            recv_src=prev_rank,
            send_group=send_group,
            recv_group=recv_group,
        )

    def p2p_communicate(
        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
    ) -> None:
@@ -127,6 +127,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_last_stage():
            self.comm.send_forward(output_object, next_rank)

    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
        """Sends the input tensor to the next stage and copies the gradient tensor from the next stage in pipeline.
        For 1F1B.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_last_stage():
            return self.comm.send_forward_recv_backward(output_object, next_rank)

    def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
        """Sends the gradient tensor to the previous stage in pipeline.
        For 1F1B.
@@ -138,6 +149,33 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        if not self.stage_manager.is_first_stage():
            self.comm.send_backward(input_object, prev_rank)

    def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
        """Sends the gradient tensor to the previous stage and copies the input tensor from the previous stage in pipeline.
        For 1F1B.

        Args:
            output_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor.
        """
        if not self.stage_manager.is_first_stage():
            return self.comm.send_backward_recv_forward(output_object, prev_rank)

    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
        For 1F1B.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the previous stage (sender of the input tensor).
            next_rank (int, optional): The rank of the next stage (recipient of the input tensor).
        """
        if self.stage_manager.is_first_stage():
            return self.comm.send_forward(input_object, next_rank)
        elif self.stage_manager.is_last_stage():
            return self.comm.recv_forward(prev_rank)
        else:
            return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)

    def forward_step(
        self,
        model: Module,
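The first and last pipeline stages only send or only receive, while middle stages can fuse both directions into one batched call; a toy illustration of that dispatch (plain Python, not ColossalAI code):

def pick_comm(stage: int, num_stages: int) -> str:
    # Mirrors the branch structure of send_forward_recv_forward above.
    if stage == 0:
        return "send_forward"            # nothing to receive from a previous stage
    elif stage == num_stages - 1:
        return "recv_forward"            # nothing to send to a next stage
    return "send_forward_recv_forward"   # fuse both directions into one batched P2P call


for stage in range(4):
    print(stage, pick_comm(stage, 4))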
@@ -291,7 +329,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

                if not last_iteration:
                    input_obj = self.recv_forward()
                else:
                    # TODO adjust here
                    self.send_forward(output_obj)
@@ -413,6 +413,7 @@ def get_llama_flash_attention_forward():
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
@@ -183,7 +183,11 @@ def main():
    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
    performance_evaluator = PerformanceEvaluator(
-       model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
+       model_numel,
+       model.config.num_hidden_layers,
+       model.config.hidden_size,
+       model.config.vocab_size,
+       args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
    )

    optimizer = HybridAdam(model.parameters())
@@ -58,6 +58,9 @@ class PerformanceEvaluator:
    def __init__(
        self,
        model_numel: int,
        num_layers: int,
        hidden_size: int,
        vocab_size: int,
        enable_grad_checkpoint: bool = False,
        ignore_steps: int = 0,
        dp_world_size: Optional[int] = None,
@@ -65,12 +68,16 @@ class PerformanceEvaluator:
        self.model_numel = model_numel
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_steps = ignore_steps
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.coordinator = DistCoordinator()
        self.dp_world_size = dp_world_size or self.coordinator.world_size
        self.disable: bool = False
        self.timer = Timer()
        self.num_samples: int = 0
        self.flop_megatron = 0
        self.flop: int = 0

    def on_step_start(self, step: int) -> None:
@@ -89,17 +96,20 @@ class PerformanceEvaluator:
        batch_size, seq_len = input_ids.shape

        self.num_samples += batch_size
        checkpoint_activations_factor = (3 + int(self.enable_grad_checkpoint))
        self.flop_megatron += (24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)) * (1. + (seq_len / (6. * self.hidden_size)) + (self.vocab_size / (16. * self.num_layers * self.hidden_size)))
        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_fit_end(self) -> None:
        avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
        mp_world_size = self.coordinator.world_size // self.dp_world_size
        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
        self.coordinator.print_on_master(
-           f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+           f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
            f"avg_throughput: {avg_throughput}"
        )
        self.coordinator.print_on_master(
-           f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+           f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )
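flop_megatron follows the per-iteration FLOP estimate popularized by the Megatron-LM paper, 24 * c * B * s * L * h^2 * (1 + s / (6h) + V / (16 L h)), where B is the batch size, s the sequence length, L the number of layers, h the hidden size, V the vocabulary size, and c = 3 for forward plus backward or 4 when activations are recomputed. A standalone restatement of the same expression, with purely illustrative shapes:

def megatron_flops_per_step(batch_size, seq_len, num_layers, hidden_size, vocab_size, grad_checkpoint=False):
    # 3 matmul passes (forward + backward), plus 1 more when activations are recomputed.
    checkpoint_activations_factor = 3 + int(grad_checkpoint)
    return (24 * checkpoint_activations_factor * batch_size * seq_len * num_layers * hidden_size**2) * (
        1.0 + seq_len / (6.0 * hidden_size) + vocab_size / (16.0 * num_layers * hidden_size)
    )


# Illustrative shapes only (roughly a 7B LLaMA-style config).
flops = megatron_flops_per_step(batch_size=8, seq_len=4096, num_layers=32, hidden_size=4096, vocab_size=32000)
print(f"~{flops / 1e12:.0f} TFLOPs per step")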