mirror of https://github.com/InternLM/InternLM
feat(core/scheduler): support pipeline parallel (#98)
* feat(utils/writer.py): support tensorboard writer * feat(utils/writer.py): add class comment * feat(core): support pipeline parallel * fix(core): fix demo running error * feat(solver/optimizer): add pp zero optimizer * fix(solver/optimizer): fix word spelling error * feat(core/scheduler): add new dir scheduler in core/ * fix(core): fix ci lint error * feat(solver/optimizer): merge pp and nopp optimizer * doc(usage.md): update usage doc * feat(core/scheduler): support post func * feat(core/scheduler): add dtype para in pp sche and update func get_tensor_shape * feat(core/scheduler): add _load_micro_batch in base scheduler * feat(core/scheduler): support optimizer overlap communication in pp scheduler * feat(core/scheduler): delete data process func code * feat(core/trainer): schedule pre processing for all schedule --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com>pull/91/head
parent
e0d6a3f84f
commit
762ab297ee
|
@ -141,4 +141,5 @@ small_demo/
|
|||
core.*
|
||||
|
||||
# Run
|
||||
llm_ckpts
|
||||
llm_ckpts
|
||||
events.*
|
|
@ -49,5 +49,5 @@ repos:
|
|||
args:
|
||||
[
|
||||
'--rcfile=.pylintrc',
|
||||
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703'
|
||||
'--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
|
||||
]
|
|
@ -123,6 +123,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node, only
|
|||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
pipeline=2,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
|
|
|
@ -174,7 +174,7 @@ parallel = dict(
|
|||
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
|
||||
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
|
||||
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
|
||||
- pipeline: pipeline parallel size, currently only supports 1, default value is 1
|
||||
- pipeline: pipeline parallel size, default value is 1
|
||||
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
|
||||
|
||||
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
|
||||
|
|
|
@ -159,7 +159,7 @@ parallel = dict(
|
|||
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
|
||||
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
|
||||
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
|
||||
- pipeline:流水线并行大小,目前只支持 1,默认值为 1
|
||||
- pipeline:流水线并行大小,默认值为 1
|
||||
- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1
|
||||
|
||||
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
from .p2p import (
|
||||
recv_backward,
|
||||
recv_forward,
|
||||
send_backward,
|
||||
send_backward_recv_backward,
|
||||
send_backward_recv_forward,
|
||||
send_forward,
|
||||
send_forward_backward_recv_forward_backward,
|
||||
send_forward_recv_backward,
|
||||
send_forward_recv_forward,
|
||||
)
|
||||
from .utils import recv_obj_meta, send_obj_meta
|
||||
|
||||
__all__ = [
|
||||
"send_forward",
|
||||
"send_forward_recv_forward",
|
||||
"send_forward_backward_recv_forward_backward",
|
||||
"send_backward",
|
||||
"send_backward_recv_backward",
|
||||
"send_backward_recv_forward",
|
||||
"send_forward_recv_backward",
|
||||
"recv_backward",
|
||||
"recv_forward",
|
||||
"send_obj_meta",
|
||||
"recv_obj_meta",
|
||||
]
|
|
@ -0,0 +1,442 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
|
||||
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
|
||||
|
||||
Args:
|
||||
tensor_shape (:class:`torch.Size`): shape of tensor
|
||||
chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
|
||||
|
||||
Returns:
|
||||
Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
|
||||
"""
|
||||
if chunk_tensor:
|
||||
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
|
||||
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
if tensor_chunk_shape % tensor_parallel_world_size == 0:
|
||||
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
chunk_tensor = False
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
return tensor_chunk_shape, chunk_tensor
|
||||
|
||||
|
||||
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
|
||||
if isinstance(recv_shapes, torch.Size):
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
|
||||
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
return buffer_recv, recv_split
|
||||
buffer_recv = []
|
||||
for recv_shape in recv_shapes:
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
|
||||
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv, recv_split
|
||||
|
||||
|
||||
def process_object_to_send(object_send, scatter_gather_tensors):
|
||||
if isinstance(object_send, torch.Tensor):
|
||||
send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
|
||||
if send_split:
|
||||
object_send = split_tensor_into_1d_equal_chunks(object_send)
|
||||
return object_send
|
||||
|
||||
object_send_list = []
|
||||
for tensor_send in object_send:
|
||||
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
|
||||
if send_split:
|
||||
object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
|
||||
else:
|
||||
object_send_list.append(tensor_send)
|
||||
object_send = tuple(object_send_list)
|
||||
|
||||
return object_send
|
||||
|
||||
|
||||
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
|
||||
ops_queue.append(op_to_add)
|
||||
else:
|
||||
for tensor_to_comm in obj:
|
||||
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
|
||||
ops_queue.append(op_to_add)
|
||||
|
||||
|
||||
def _communicate(
|
||||
object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
|
||||
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
|
||||
recv_prev: bool = False,
|
||||
recv_next: bool = False,
|
||||
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
prev_rank: int = None,
|
||||
next_rank: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
"""
|
||||
Adapted from megatron.p2p_communication.
|
||||
Communicate tensors between stages. Used as helper method in other
|
||||
communication methods that are used in pipeline schedule.
|
||||
Takes the following arguments:
|
||||
object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
|
||||
(no tensor sent if set to None).
|
||||
object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
|
||||
(no tensor sent if set to None).
|
||||
recv_prev (bool): boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next (bool): boolean for whether tensor should be received from
|
||||
next rank.
|
||||
recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
|
||||
from the previous stage, defualts to None.
|
||||
recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
|
||||
from the next stage, defualts to None.
|
||||
prev_rank (int): the rank of the previous pipeline stage, defualts to None,
|
||||
next_rank (int): the rank of the next pipeline stage, defualts to None,
|
||||
dtype (torch.dtype): data type of intermediate buffers, defaults to None
|
||||
scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
|
||||
|
||||
Returns:
|
||||
Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
|
||||
"""
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
|
||||
if recv_prev:
|
||||
assert recv_prev_shape is not None
|
||||
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
|
||||
recv_prev_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
if recv_next:
|
||||
assert recv_next_shape is not None
|
||||
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
|
||||
recv_next_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
if object_send_prev is not None or recv_prev:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if object_send_next is not None or recv_next:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
if object_send_prev is not None:
|
||||
object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
|
||||
|
||||
if object_send_next is not None:
|
||||
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
|
||||
|
||||
ops = []
|
||||
if object_send_prev is not None:
|
||||
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
|
||||
|
||||
if tensor_recv_prev is not None:
|
||||
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
|
||||
|
||||
if tensor_recv_next is not None:
|
||||
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
|
||||
|
||||
if object_send_next is not None:
|
||||
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = dist.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if recv_prev and recv_prev_split:
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_prev)):
|
||||
tensor_recv_prev[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
if recv_next and recv_next_split:
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_next)):
|
||||
tensor_recv_next[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
return tensor_recv_prev, tensor_recv_next
|
||||
|
||||
|
||||
def recv_forward(
|
||||
input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
prev_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
|
||||
"""
|
||||
if gpc.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(
|
||||
recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(
|
||||
output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||
|
||||
Args:
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
next_rank (int, optional): The rank of the source of the tensor.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
|
||||
"""
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(
|
||||
recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
|
||||
"""Sends the input tensor to the next stage in pipeline.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not gpc.is_pipeline_last_stage():
|
||||
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
_communicate(
|
||||
object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
|
||||
)
|
||||
|
||||
|
||||
def send_forward_recv_backward(
|
||||
output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the gradient tensor from the
|
||||
next stage in pipeline as the input gradient tensor of this stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
|
||||
"""
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the gradient tensor to the
|
||||
previous stage in pipeline, while receives the output tensor from the
|
||||
previous stage in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
|
||||
"""
|
||||
if gpc.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(
|
||||
output_tensor,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the output tensor from the
|
||||
previous stage in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
|
||||
"""
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the gradient tensor to the
|
||||
previous stage in pipeline, while receives the gradient tensor from the
|
||||
next member in pipeline as the input of this stage.
|
||||
|
||||
Args:
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
|
||||
to be received.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
|
||||
"""
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(
|
||||
output_tensor,
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
output_grad_shape,
|
||||
recv_prev=True,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
|
||||
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
|
||||
the gradient tensor to the previous stage, while receives the input gradient tensor from the
|
||||
next stage and the input tensor from the previous stage.
|
||||
|
||||
Args:
|
||||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
|
||||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
|
||||
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
|
||||
from the previous.
|
||||
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
|
||||
from the next.
|
||||
|
||||
Returns:
|
||||
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
|
||||
List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
|
||||
"""
|
||||
input_tensor, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor, output_tensor_grad
|
|
@ -0,0 +1,129 @@
|
|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
||||
def send_meta_helper(obj, next_rank, tensor_kwargs):
|
||||
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
|
||||
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
|
||||
dist.send(send_ndims, next_rank)
|
||||
dist.send(send_shape, next_rank)
|
||||
|
||||
|
||||
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
||||
"""Sends obj meta information before sending a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be sent before communications. This function
|
||||
synchronizes with :func:`recv_obj_meta`.
|
||||
|
||||
Args:
|
||||
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
|
||||
need_meta (bool, optional): If False, meta information won't be sent.
|
||||
next_rank (int): The rank of the next member in pipeline parallel group.
|
||||
|
||||
Returns:
|
||||
bool: False
|
||||
"""
|
||||
if need_meta:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
send_meta_helper(obj, next_rank, tensor_kwargs)
|
||||
else:
|
||||
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
for tensor_to_send in obj:
|
||||
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def recv_meta_helper(prev_rank, tensor_kwargs):
|
||||
recv_ndims = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_ndims, prev_rank)
|
||||
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
|
||||
dist.recv(recv_shape, prev_rank)
|
||||
return recv_shape
|
||||
|
||||
|
||||
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
||||
"""Receives obj meta information before receiving a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be received before communications. This function
|
||||
synchronizes with :func:`send_obj_meta`.
|
||||
|
||||
Args:
|
||||
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
|
||||
prev_rank (int): The rank of the source of the obj.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
|
||||
"""
|
||||
if obj_shape is None:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape = torch.Size(recv_shape)
|
||||
else:
|
||||
obj_shape = []
|
||||
for _ in range(recv_obj_nums.item()):
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape.append(torch.Size(recv_shape))
|
||||
|
||||
return obj_shape
|
||||
|
||||
|
||||
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
|
||||
"""Break a tensor into equal 1D chunks.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): Tensor to be split before communication.
|
||||
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The split tensor
|
||||
"""
|
||||
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
|
||||
start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
end_index = start_index + partition_size
|
||||
if new_buffer:
|
||||
data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
|
||||
data.copy_(tensor.view(-1)[start_index:end_index])
|
||||
else:
|
||||
data = tensor.view(-1)[start_index:end_index]
|
||||
return data
|
||||
|
||||
|
||||
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Opposite of above function, gather values from model parallel ranks.
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The gathered tensor.
|
||||
"""
|
||||
world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
numel = torch.numel(tensor)
|
||||
numel_gathered = world_size * numel
|
||||
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
|
||||
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
|
||||
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
|
||||
return gathered
|
|
@ -0,0 +1,5 @@
|
|||
from .base_scheduler import BaseScheduler
|
||||
from .no_pipeline_scheduler import NonPipelineScheduler
|
||||
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
|
||||
|
||||
__all__ = ["BaseScheduler", "NonPipelineScheduler", "InterleavedPipelineScheduler", "PipelineScheduler"]
|
|
@ -0,0 +1,114 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
It mainly composes of forward_backward_step for gradient backward and
|
||||
optimizer_step for parameters update.
|
||||
For the convenience to enable FP16, we aggregate all codes that contain the
|
||||
control of FP16 in class schedule.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
|
||||
them into data and label.
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None):
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
@abstractmethod
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""To perform actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _load_micro_batch(self, data, label, offset, micro_bsz):
|
||||
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
|
||||
micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
|
||||
micro_batch_label = label[offset : offset + micro_bsz]
|
||||
|
||||
return micro_batch_data, micro_batch_label
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
forward_only (bool): If True, the process won't include backward.
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine: Engine, inputs: Any):
|
||||
"""Calls the engine with the given inputs.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return engine(inputs)
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
return engine(*inputs)
|
||||
elif isinstance(inputs, dict):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
|
||||
"""Calls the engine's criterion with the given outputs and labels.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
|
||||
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = (labels,)
|
||||
|
||||
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
|
||||
return engine.criterion(*outputs, *labels)
|
||||
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
|
||||
return engine.criterion(*outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, dict):
|
||||
return engine.criterion(**outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
|
@ -2,8 +2,8 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import torch
|
||||
|
@ -11,101 +11,7 @@ import torch
|
|||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
It mainly composes of forward_backward_step for gradient backward and
|
||||
optimizer_step for parameters update.
|
||||
For the convenience to enable FP16, we aggregate all codes that contain the
|
||||
control of FP16 in class schedule.
|
||||
|
||||
Args:
|
||||
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
|
||||
them into data and label.
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None):
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
@abstractmethod
|
||||
def pre_processing(self, engine: Engine):
|
||||
"""To perform actions before running the schedule.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
forward_only (bool): If True, the process won't include backward.
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine: Engine, inputs: Any):
|
||||
"""Calls the engine with the given inputs.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
return engine(inputs)
|
||||
elif isinstance(inputs, (list, tuple)):
|
||||
return engine(*inputs)
|
||||
elif isinstance(inputs, dict):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
|
||||
"""Calls the engine's criterion with the given outputs and labels.
|
||||
|
||||
Args:
|
||||
engine (internlm.core.Engine): InternLM engine for training and inference.
|
||||
outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
|
||||
labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
|
||||
"""
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = (labels,)
|
||||
|
||||
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
|
||||
return engine.criterion(*outputs, *labels)
|
||||
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
|
||||
return engine.criterion(*outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, dict):
|
||||
return engine.criterion(**outputs, **labels)
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
from .base_scheduler import BaseScheduler
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
|
@ -161,12 +67,10 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
data (Any): The data to be loaded.
|
||||
label (Any): The label to be loaded.
|
||||
"""
|
||||
_data = {
|
||||
k: v[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
for k, v in data.items()
|
||||
}
|
||||
_label = label[self._grad_accum_offset : self._grad_accum_offset + self._grad_accum_batch_size]
|
||||
|
||||
_data, _label = self._load_micro_batch(
|
||||
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
|
||||
)
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
|
||||
return _data, _label
|
|
@ -0,0 +1,844 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import torch.cuda
|
||||
|
||||
import internlm.core.communication as comm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.common import get_current_device, move_to_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
|
||||
from .base_scheduler import BaseScheduler
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def get_tensor_shape():
|
||||
if hasattr(gpc.config, "TENSOR_SHAPE"):
|
||||
return gpc.config.TENSOR_SHAPE
|
||||
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
return None
|
||||
|
||||
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
return tensor_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def pack_return_tensors(return_tensors):
|
||||
output, label = tuple(zip(*return_tensors))
|
||||
if isinstance(output[0], torch.Tensor):
|
||||
output = torch.cat(output, dim=0)
|
||||
elif isinstance(output[0], (list, tuple)):
|
||||
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
|
||||
else:
|
||||
raise TypeError("Output of model must be tensor or list/tuple of tensors")
|
||||
if isinstance(label[0], torch.Tensor):
|
||||
label = torch.cat(label, dim=0)
|
||||
else:
|
||||
merged_label = {k: [] for k in label[0].keys()}
|
||||
for d in label:
|
||||
for k, v in d.items():
|
||||
merged_label[k].append(v)
|
||||
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
|
||||
return output, label
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_virtual_pipeline_parallel_rank(rank):
|
||||
prev_rank = gpc.virtual_pipeline_parallel_rank
|
||||
try:
|
||||
gpc.set_virtual_pipeline_parallel_rank(rank)
|
||||
yield
|
||||
finally:
|
||||
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
|
||||
|
||||
|
||||
class PipelineScheduler(BaseScheduler):
|
||||
"""A helper schedule class for pipeline parallelism running environment.
|
||||
It uses non-interleaved 1F1B strategy. Other properties are similar as
|
||||
:class:`NonPipelineSchedule`.
|
||||
|
||||
Args:
|
||||
num_microbatches (int): The number of microbatches.
|
||||
data_process_func (Callable, optional):
|
||||
The post processing function which receives a micro batch of data, and it will be executed
|
||||
in `load_micro_batch`.
|
||||
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
|
||||
scatter_gather_tensors (bool, optional):
|
||||
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_microbatches,
|
||||
dtype=torch.float,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
):
|
||||
super().__init__(data_process_func=data_process_func)
|
||||
|
||||
assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}"
|
||||
|
||||
self.num_microbatches = num_microbatches
|
||||
self.dtype = dtype
|
||||
assert not isinstance(
|
||||
tensor_shape, int
|
||||
), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
||||
if tensor_shape is None:
|
||||
self.tensor_shape = tensor_shape
|
||||
elif isinstance(tensor_shape, torch.Size):
|
||||
self.tensor_shape = tensor_shape
|
||||
else:
|
||||
self.tensor_shape = torch.Size(tensor_shape)
|
||||
self.scatter_gather_tensors = False
|
||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
self.scatter_gather_tensors = scatter_gather_tensors
|
||||
|
||||
# cache for the batch data
|
||||
self.batch_data = None
|
||||
|
||||
def load_batch(self, engine, data_iter):
|
||||
# Pipeline schedule just puts data in memory
|
||||
batch_data, self.batch_size = engine.load_batch(data_iter, to_gpu=False)
|
||||
self.batch_data, self.batch_label = batch_data
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
|
||||
def load_micro_batch(self):
|
||||
mciro_batch_data, micro_batch_label = self._load_micro_batch(
|
||||
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
|
||||
)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
|
||||
# unpack data process
|
||||
# TODO by xyt
|
||||
return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
|
||||
|
||||
def pre_processing(self, engine):
|
||||
model = engine.model
|
||||
types = set()
|
||||
for param in model.parameters():
|
||||
types.add(param.dtype)
|
||||
assert len(types) == 1, f"Mixed types of parameter detected, {types}"
|
||||
_dtype = types.pop()
|
||||
self.dtype = _dtype
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(model, data): # pylint: disable=W0237
|
||||
if data is not None:
|
||||
if isinstance(data, torch.Tensor):
|
||||
return model(data)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
return model(*data)
|
||||
elif isinstance(data, dict):
|
||||
stage_output = None
|
||||
if "stage_output" in data:
|
||||
stage_output = data.pop("stage_output")
|
||||
if stage_output is None:
|
||||
return model(**data)
|
||||
elif isinstance(stage_output, torch.Tensor):
|
||||
return model(stage_output, **data)
|
||||
elif isinstance(stage_output, (tuple, list)):
|
||||
return model(*stage_output, **data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected stage_output to be of type torch.Tensor, list, or tuple, "
|
||||
f"but got {type(stage_output)}"
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
|
||||
def _get_data_label_for_current_step(self, stage_output, micro_batch_data, micro_batch_label):
|
||||
if isinstance(micro_batch_data, (tuple, list)):
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
# for the first stage, we use the data from the
|
||||
# dataloader output by default
|
||||
data, label = micro_batch_data
|
||||
else:
|
||||
# for non-first stage, we use the output passed
|
||||
# by the previous as the model input
|
||||
data = stage_output
|
||||
_, label = micro_batch_data
|
||||
elif isinstance(micro_batch_data, dict):
|
||||
data = {}
|
||||
data["stage_output"] = stage_output
|
||||
if "label" in micro_batch_data:
|
||||
label = micro_batch_data.pop("label")
|
||||
else:
|
||||
label = micro_batch_label
|
||||
load_data = micro_batch_data
|
||||
data.update(load_data)
|
||||
|
||||
return data, label
|
||||
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, **kwargs):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
|
||||
Args:
|
||||
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
||||
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
|
||||
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
|
||||
return_output_label (bool, optional): Whether returns output labels.
|
||||
accum_loss (optional): Where accumulated loss stores.
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
||||
pipeline stage.
|
||||
"""
|
||||
micro_batch_data, micro_batch_label = self.load_micro_batch()
|
||||
|
||||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)
|
||||
timer("fwd").start()
|
||||
output_obj = self._call_engine(engine.model, data)
|
||||
timer("fwd").stop()
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
timer("post_fn").start()
|
||||
post_func = kwargs.get("post_fn")
|
||||
if post_func is not None:
|
||||
post_func(output_obj, label)
|
||||
timer("post_fn").stop()
|
||||
|
||||
if return_output_label:
|
||||
return_tensors.append((output_obj, label))
|
||||
if accum_loss is not None:
|
||||
timer("cal_loss").start()
|
||||
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
|
||||
accum_loss.add_(loss_reduced.detach())
|
||||
timer("cal_loss").stop()
|
||||
return loss_reduced
|
||||
else:
|
||||
# forward only, it's useless since backward is not needed
|
||||
return output_obj
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
|
||||
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
||||
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
||||
Returns the gradients with respect to the input tensor (None if first stage).
|
||||
This is a helper function and can be ignored by users.
|
||||
|
||||
Args:
|
||||
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
||||
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
|
||||
output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this
|
||||
pipeline stage.
|
||||
output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for
|
||||
this pipeline stage.
|
||||
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
|
||||
"""
|
||||
|
||||
# Retain the grad on the input_obj.
|
||||
if input_obj is not None:
|
||||
if isinstance(input_obj, torch.Tensor):
|
||||
input_obj.retain_grad()
|
||||
else:
|
||||
for in_tensor in input_obj:
|
||||
if in_tensor is not None:
|
||||
in_tensor.retain_grad()
|
||||
timer("bwd").start()
|
||||
# Backward pass.
|
||||
if output_obj_grad is None:
|
||||
engine.backward(output_obj)
|
||||
else:
|
||||
engine.backward_by_grad(output_obj, output_obj_grad)
|
||||
timer("bwd").stop()
|
||||
# Collect the grad of the input_obj.
|
||||
input_obj_grad = None
|
||||
if input_obj is not None:
|
||||
if isinstance(input_obj, torch.Tensor):
|
||||
input_obj_grad = input_obj.grad
|
||||
else:
|
||||
input_obj_grad = []
|
||||
for in_tensor in input_obj:
|
||||
input_obj_grad.append(in_tensor.grad)
|
||||
|
||||
return input_obj_grad
|
||||
|
||||
def forward_backward_step(
|
||||
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, **kwargs
|
||||
):
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
||||
Args:
|
||||
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
||||
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
|
||||
forward_only (bool, optional):
|
||||
Whether run forward step only. Default is false. If true, no backward will be run.
|
||||
return_loss (bool, optional): Whether returns the loss value. Default is true.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
self.load_batch(engine, data_iter)
|
||||
num_warmup_microbatches = (
|
||||
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
||||
)
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
|
||||
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
|
||||
|
||||
# only the last micro batch backward need to reduce gradients
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_objs = None
|
||||
output_objs = None
|
||||
if not forward_only:
|
||||
input_objs = []
|
||||
output_objs = []
|
||||
return_tensors = []
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
# Used for tensor meta information communication
|
||||
ft_shapes = self.tensor_shape
|
||||
bt_shapes = None
|
||||
fs_checker = self.tensor_shape is None
|
||||
|
||||
# Run warmup forward passes.
|
||||
for i in range(num_warmup_microbatches):
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shapes = comm.recv_obj_meta(ft_shapes)
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
output_obj = self._forward_step(
|
||||
engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
**kwargs,
|
||||
)
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
if isinstance(output_obj, torch.Tensor):
|
||||
bt_shapes = output_obj.shape
|
||||
else:
|
||||
bt_shapes = []
|
||||
for out_tensor in output_obj:
|
||||
bt_shapes.append(out_tensor.shape)
|
||||
fs_checker = comm.send_obj_meta(output_obj, fs_checker)
|
||||
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
if not forward_only:
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
# receive this tensor here.
|
||||
if num_microbatches_remaining > 0:
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shapes = comm.recv_obj_meta(ft_shapes)
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self._forward_step(
|
||||
engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
**kwargs,
|
||||
)
|
||||
if forward_only:
|
||||
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
else:
|
||||
output_obj_grad = comm.send_forward_recv_backward(
|
||||
output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
|
||||
if num_warmup_microbatches == 0 and last_iteration:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
|
||||
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
if last_iteration:
|
||||
input_obj = None
|
||||
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
else:
|
||||
input_obj = comm.send_backward_recv_forward(
|
||||
input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
|
||||
output_obj_grad = comm.recv_backward(
|
||||
bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
if num_warmup_microbatches > 0 and i == num_warmup_microbatches - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
|
||||
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
if len(return_tensors) > 0:
|
||||
output, label = pack_return_tensors(return_tensors)
|
||||
return output, label, accum_loss
|
||||
else:
|
||||
return None, None, accum_loss
|
||||
|
||||
|
||||
class InterleavedPipelineScheduler(PipelineScheduler):
|
||||
"""
|
||||
Interleaved Pipeline Scheduler.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_microbatches: int,
|
||||
num_model_chunks: int,
|
||||
dtype=torch.float,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
):
|
||||
"""A helper schedule class for pipeline parallelism running environment.
|
||||
It uses interleaved 1F1B strategy. Other properties are similar as
|
||||
:class:`NonPipelineSchedule`.
|
||||
|
||||
Args:
|
||||
num_microbatches (int): The number of microbatches.
|
||||
num_model_chunks (int): The number of model chunks.
|
||||
data_process_func (Callable, optional):
|
||||
The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
|
||||
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
|
||||
scatter_gather_tensors (bool, optional):
|
||||
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
|
||||
"""
|
||||
assert (
|
||||
num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
|
||||
), "num_microbatches must be an integer multiple of pipeline parallel world size"
|
||||
assert (
|
||||
isinstance(num_model_chunks, int) and num_model_chunks > 0
|
||||
), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}"
|
||||
super().__init__(
|
||||
num_microbatches,
|
||||
dtype=dtype,
|
||||
data_process_func=data_process_func,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
|
||||
gpc.set_virtual_pipeline_parallel_rank(0)
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
||||
def pre_processing(self, engine):
|
||||
for model in engine.model:
|
||||
if isinstance(model, NaiveAMPModel):
|
||||
model = model.model
|
||||
sig = inspect.signature(model.forward)
|
||||
for p in sig.parameters.values():
|
||||
assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported"
|
||||
|
||||
def load_batch(self, engine, data_iter):
|
||||
super().load_batch(engine, data_iter)
|
||||
# overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
|
||||
def load_micro_batch(self, model_chunk_id):
|
||||
mciro_batch_data, micro_batch_label = self._load_micro_batch(
|
||||
data=self.batch_data,
|
||||
label=self.batch_label,
|
||||
offset=self.microbatch_offset[model_chunk_id],
|
||||
micro_bsz=self.microbatch_size,
|
||||
)
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
|
||||
|
||||
def _forward_step( # pylint: disable=W0237
|
||||
self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None, **kwargs
|
||||
):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
|
||||
Args:
|
||||
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
||||
model_chunk_id (int): The id of model chunks.
|
||||
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
|
||||
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
|
||||
return_output_label (bool, optional): Whether returns output labels.
|
||||
accum_loss (optional): Where accumulated loss stores.
|
||||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
||||
pipeline stage.
|
||||
"""
|
||||
micro_batch_data, micro_batch_label = self.load_micro_batch(model_chunk_id)
|
||||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)
|
||||
|
||||
output_obj = self._call_engine(engine.model[model_chunk_id], data)
|
||||
|
||||
if gpc.is_pipeline_last_stage():
|
||||
timer("post_fn").start()
|
||||
post_func = kwargs.get("post_fn")
|
||||
if post_func is not None:
|
||||
post_func(output_obj, label)
|
||||
timer("post_fn").stop()
|
||||
|
||||
if return_output_label:
|
||||
return_tensors.append((output_obj, label))
|
||||
if accum_loss is not None:
|
||||
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
|
||||
accum_loss.add_(loss_reduced.detach())
|
||||
return loss_reduced
|
||||
else:
|
||||
# forward only, it's useless since backward is not needed
|
||||
return output_obj
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def forward_backward_step(
|
||||
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, **kwargs
|
||||
):
|
||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||
communication between pipeline stages as needed.
|
||||
|
||||
Args:
|
||||
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
||||
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
|
||||
forward_only (bool, optional):
|
||||
Whether run forward step only. Default is false. If true, no backward will be run.
|
||||
return_loss (bool, optional): Whether returns the loss value. Default is true.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
The loss would be returned only in the last stage.
|
||||
"""
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
self.load_batch(engine, data_iter)
|
||||
model = engine.model
|
||||
input_objs = [[] for _ in range(len(model))]
|
||||
output_objs = [[] for _ in range(len(model))]
|
||||
return_tensors = []
|
||||
if not forward_only:
|
||||
output_obj_grads = [[] for _ in range(len(model))]
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
# only the last micro batch backward need to reduce gradients
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
# Used for obj meta information communication
|
||||
input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
|
||||
output_obj_shapes = [None for _ in range(len(model))]
|
||||
send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]
|
||||
|
||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
# Compute number of warmup and remaining microbatches.
|
||||
num_model_chunks = len(model)
|
||||
num_microbatches = self.num_microbatches * num_model_chunks
|
||||
all_warmup_microbatches = False
|
||||
if forward_only:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
else:
|
||||
# Run all forward passes and then all backward passes if number of
|
||||
# microbatches is just the number of pipeline stages.
|
||||
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
|
||||
# all workers, followed by more microbatches after depending on
|
||||
# stage ID (more forward passes for earlier stages, later stages can
|
||||
# immediately start with 1F1B).
|
||||
if self.num_microbatches == pipeline_parallel_size:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
all_warmup_microbatches = True
|
||||
else:
|
||||
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
|
||||
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
|
||||
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
|
||||
|
||||
def get_model_chunk_id(microbatch_id, forward):
|
||||
"""Helper method to get the model chunk ID given the iteration number."""
|
||||
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
|
||||
if not forward:
|
||||
model_chunk_id = num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def _forward_step_helper(microbatch_id):
|
||||
"""Helper method to run forward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
forward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
|
||||
gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
|
||||
|
||||
# forward step
|
||||
if gpc.is_pipeline_first_stage():
|
||||
if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]):
|
||||
input_objs[model_chunk_id].append(None)
|
||||
input_obj = input_objs[model_chunk_id][-1]
|
||||
output_obj = self._forward_step(
|
||||
engine,
|
||||
model_chunk_id,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
**kwargs,
|
||||
)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
# if forward-only, no need to save tensors for a backward pass
|
||||
if forward_only:
|
||||
input_objs[model_chunk_id].pop()
|
||||
output_objs[model_chunk_id].pop()
|
||||
|
||||
return output_obj
|
||||
|
||||
def _backward_step_helper(microbatch_id):
|
||||
"""Helper method to run backward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
backward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
|
||||
gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
|
||||
|
||||
if gpc.is_pipeline_last_stage():
|
||||
if len(output_obj_grads[model_chunk_id]) == 0:
|
||||
output_obj_grads[model_chunk_id].append(None)
|
||||
input_obj = input_objs[model_chunk_id].pop(0)
|
||||
output_obj = output_objs[model_chunk_id].pop(0)
|
||||
output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
|
||||
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
return input_obj_grad
|
||||
|
||||
# Run warmup forward passes.
|
||||
gpc.set_virtual_pipeline_parallel_rank(0)
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
|
||||
input_objs[0].append(
|
||||
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
)
|
||||
|
||||
for k in range(num_warmup_microbatches):
|
||||
model_chunk_id = get_model_chunk_id(k, forward=True)
|
||||
output_obj = _forward_step_helper(k)
|
||||
if not gpc.is_pipeline_last_stage():
|
||||
if isinstance(output_obj, torch.Tensor):
|
||||
output_obj_shapes[model_chunk_id] = output_obj.shape
|
||||
else:
|
||||
output_obj_shapes[model_chunk_id] = []
|
||||
for out_tensor in output_obj:
|
||||
output_obj_shapes[model_chunk_id].append(out_tensor.shape)
|
||||
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(
|
||||
output_obj, send_tensor_shape_flags[model_chunk_id]
|
||||
)
|
||||
# Determine if tensor should be received from previous stage.
|
||||
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
|
||||
recv_prev = True
|
||||
if gpc.is_pipeline_first_stage(ignore_virtual=True):
|
||||
if next_forward_model_chunk_id == 0:
|
||||
recv_prev = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_prev = False
|
||||
|
||||
# Don't send tensor downstream if on last stage.
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_obj = None
|
||||
|
||||
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
|
||||
input_obj_shapes[next_forward_model_chunk_id]
|
||||
)
|
||||
# Send and receive tensors as appropriate (send tensors computed
|
||||
# in this iteration; receive tensors for next iteration).
|
||||
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
|
||||
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
|
||||
input_obj_grad = None
|
||||
recv_next = True
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
recv_next = False
|
||||
output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
|
||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj,
|
||||
input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
|
||||
else:
|
||||
input_obj = comm.send_forward_recv_forward(
|
||||
output_obj,
|
||||
input_shape,
|
||||
recv_prev=recv_prev,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
input_objs[next_forward_model_chunk_id].append(input_obj)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for k in range(num_microbatches_remaining):
|
||||
# Forward pass.
|
||||
forward_k = k + num_warmup_microbatches
|
||||
output_obj = _forward_step_helper(forward_k)
|
||||
|
||||
# Backward pass.
|
||||
backward_k = k
|
||||
if num_warmup_microbatches == 0 and k == num_microbatches_remaining - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
input_obj_grad = _backward_step_helper(backward_k)
|
||||
|
||||
# Send output_obj and input_obj_grad, receive input_obj
|
||||
# and output_obj_grad.
|
||||
|
||||
# Determine if current stage has anything to send in either direction,
|
||||
# otherwise set obj to None.
|
||||
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
|
||||
gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_obj = None
|
||||
|
||||
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
|
||||
gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
|
||||
if gpc.is_pipeline_first_stage():
|
||||
input_obj_grad = None
|
||||
|
||||
# Determine if peers are sending, and where in data structure to put
|
||||
# received tensors.
|
||||
recv_prev = True
|
||||
if gpc.is_pipeline_first_stage(ignore_virtual=True):
|
||||
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
|
||||
next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
|
||||
if next_forward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_prev = False
|
||||
next_forward_model_chunk_id += 1
|
||||
else:
|
||||
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
|
||||
|
||||
recv_next = True
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
|
||||
next_backward_model_chunk_id = get_model_chunk_id(
|
||||
backward_k - (pipeline_parallel_size - 1), forward=False
|
||||
)
|
||||
if next_backward_model_chunk_id == 0:
|
||||
recv_next = False
|
||||
next_backward_model_chunk_id -= 1
|
||||
else:
|
||||
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
|
||||
|
||||
# If last iteration, don't receive; we already received one extra
|
||||
# before the start of the for loop.
|
||||
if k == (num_microbatches_remaining - 1):
|
||||
recv_prev = False
|
||||
|
||||
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
|
||||
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
|
||||
# Communicate objs.
|
||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj,
|
||||
input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
|
||||
# Put input_obj and output_obj_grad in data structures in the
|
||||
# right location.
|
||||
if recv_prev:
|
||||
input_objs[next_forward_model_chunk_id].append(input_obj)
|
||||
if recv_next:
|
||||
output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)
|
||||
|
||||
# Run cooldown backward passes (flush out pipeline).
|
||||
if not forward_only:
|
||||
if all_warmup_microbatches:
|
||||
output_obj_grads[num_model_chunks - 1].append(
|
||||
comm.recv_backward(
|
||||
output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
)
|
||||
for k in range(num_microbatches_remaining, num_microbatches):
|
||||
if k == num_microbatches - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
input_obj_grad = _backward_step_helper(k)
|
||||
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
|
||||
recv_next = True
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
if next_backward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_next = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_next = False
|
||||
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
|
||||
output_obj_grads[next_backward_model_chunk_id].append(
|
||||
comm.send_backward_recv_backward(
|
||||
input_obj_grad,
|
||||
output_shape,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
)
|
||||
|
||||
if len(return_tensors) > 0:
|
||||
output, label = pack_return_tensors(return_tensors)
|
||||
return output, label, accum_loss
|
||||
else:
|
||||
return None, None, accum_loss
|
|
@ -7,7 +7,12 @@ import json
|
|||
from typing import Iterable, Optional
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
|
||||
from internlm.core.scheduler import (
|
||||
BaseScheduler,
|
||||
InterleavedPipelineScheduler,
|
||||
NonPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
)
|
||||
|
||||
|
||||
class TrainState:
|
||||
|
@ -112,8 +117,7 @@ class Trainer:
|
|||
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
|
||||
self._schedule = schedule
|
||||
|
||||
if self.uses_pipeline:
|
||||
self._schedule.pre_processing(self)
|
||||
self._schedule.pre_processing(self._engine)
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
|
@ -126,7 +130,7 @@ class Trainer:
|
|||
@property
|
||||
def uses_pipeline(self):
|
||||
"""Returns whether the pipeline parallel is used or not."""
|
||||
return False
|
||||
return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
|
||||
|
||||
def train(self):
|
||||
self._engine.train()
|
||||
|
|
|
@ -219,11 +219,6 @@ class StaticBatchSampler:
|
|||
assert (
|
||||
batch_size - self.start_bsz
|
||||
) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
|
||||
assert (
|
||||
self.start_bsz // micro_bsz >= 4
|
||||
), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
|
||||
`{micro_bsz}`, so that the pipeline can run correctly"
|
||||
|
||||
assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
|
||||
assert (
|
||||
self.start_bsz % micro_bsz == 0
|
||||
|
|
|
@ -11,10 +11,16 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
|
||||
from internlm.core.scheduler.no_pipeline_scheduler import NonPipelineScheduler
|
||||
from internlm.core.scheduler.pipeline_scheduler import (
|
||||
InterleavedPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
get_tensor_shape,
|
||||
)
|
||||
from internlm.core.trainer import Trainer
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
||||
|
@ -59,6 +65,8 @@ def initialize_trainer(
|
|||
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
|
||||
|
||||
# gradient handler, only support PipelineSharedModuleGradientHandler now
|
||||
if gpc.is_using_pp():
|
||||
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
|
||||
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
|
||||
gradient_handlers = []
|
||||
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
|
||||
|
@ -67,8 +75,36 @@ def initialize_trainer(
|
|||
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
|
||||
gradient_handlers.append(handler)
|
||||
|
||||
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
# initialize scheduler for trainer
|
||||
scheduler = None
|
||||
if gpc.is_using_pp():
|
||||
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
|
||||
tensor_shape = get_tensor_shape()
|
||||
use_interleaved = (
|
||||
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
|
||||
)
|
||||
scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
|
||||
if use_interleaved:
|
||||
if isinstance(model, nn.Sequential):
|
||||
model = nn.ModuleList([model])
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
num_model_chunks=gpc.config.model.num_chunks,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
)
|
||||
else:
|
||||
scheduler = PipelineScheduler(
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
)
|
||||
else:
|
||||
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
|
||||
# initialize engine for trainer
|
||||
engine = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
|
|
|
@ -402,9 +402,10 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
|
|||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
# all_parts = partition_uniform_with_embed2(num_layers, pipeline_size, num_chunks)
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"The layer sharding is {all_parts}.")
|
||||
|
||||
models = []
|
||||
|
||||
|
|
|
@ -3,11 +3,8 @@
|
|||
|
||||
from functools import partial
|
||||
|
||||
import amp_C
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
from torch._six import inf
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from internlm.core.context import Config, ParallelMode
|
||||
|
@ -28,32 +25,15 @@ from internlm.solver.optimizer.utils import (
|
|||
split_half_float_double,
|
||||
sync_param,
|
||||
)
|
||||
from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
|
||||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import is_model_parallel_parameter
|
||||
|
||||
from .utils import compute_norm
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
class BaseOptimizer(Optimizer):
|
||||
"""
|
||||
Base Optimizer.
|
||||
|
@ -239,8 +219,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
|
||||
self.skip_grad_reduce = False
|
||||
|
||||
# intialize communication stream for
|
||||
# communication-compuation overlapping
|
||||
# initialize communication stream for
|
||||
# communication-computation overlapping
|
||||
if self._overlap_communication:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
|
||||
|
@ -732,87 +712,3 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
if "zero_devide_optim_plan" in states:
|
||||
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, norm_type=2):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters, need total_norm**(1/norm) before using.
|
||||
"""
|
||||
|
||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||
# Norm parameters.
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
for g, p in zip(gradients, parameters):
|
||||
# TODO: consider the pipeline shared parameter
|
||||
if (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) == 0
|
||||
): # if shared between different pipe, only count o
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) != 0
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and not is_model_parallel_parameter(p)
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
): # if not used in each chunk, such as layernorm
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Should not arrive here")
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
|
||||
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
|
||||
|
||||
total_norm = tensor_parallel_norm
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
|
|
@ -4,14 +4,19 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
import amp_C
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex.multi_tensor_apply import multi_tensor_applier
|
||||
from torch import Tensor
|
||||
from torch._six import inf
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.parallel import is_model_parallel_parameter
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
@ -150,6 +155,108 @@ def sync_param(flat_tensor, tensor_list):
|
|||
p.data = q.data
|
||||
|
||||
|
||||
def calc_l2_norm(grads):
|
||||
norm = 0.0
|
||||
if len(grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
def calc_lp(grads, norm_type):
|
||||
norm = 0.0
|
||||
for grad in grads:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
norm += grad_norm**norm_type
|
||||
return norm
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, norm_type=2):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
parameters (Iterable[Tensor]): The parameter each gradient corresponds to.
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the parameters, need total_norm**(1/norm) before using.
|
||||
"""
|
||||
|
||||
enable_cuda_kernels = gradients[0].device.type == "cuda"
|
||||
# Norm parameters.
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
for g, p in zip(gradients, parameters):
|
||||
# TODO: consider the pipeline shared parameter
|
||||
if (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) == 0
|
||||
): # if shared between different pipe, only count o
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.PIPELINE)
|
||||
and hasattr(p, "pipeline_shared_module_pg")
|
||||
and dist.get_rank(p.pipeline_shared_module_pg) != 0
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
gpc.is_initialized(ParallelMode.TENSOR)
|
||||
and not is_model_parallel_parameter(p)
|
||||
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||
): # if not used in each chunk, such as layernorm
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif is_model_parallel_parameter(p):
|
||||
tensor_parallel_grads.append(g.data.float())
|
||||
elif gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError("Should not arrive here")
|
||||
|
||||
if norm_type == 2.0 and enable_cuda_kernels:
|
||||
tensor_parallel_norm = calc_l2_norm(tensor_parallel_grads) ** norm_type
|
||||
else:
|
||||
tensor_parallel_norm = calc_lp(tensor_parallel_grads, norm_type)
|
||||
|
||||
# If norm is type of float, then we convert them into torch.Tensor.
|
||||
tensor_parallel_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
|
||||
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
|
||||
if not enable_cuda_kernels:
|
||||
tensor_parallel_norm = move_norm_to_cuda(tensor_parallel_norm)
|
||||
|
||||
total_norm = tensor_parallel_norm
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
"""A base class for the gradient scaler.
|
||||
|
||||
|
|
|
@ -81,28 +81,12 @@ def move_to_device(data):
|
|||
data_to_return = []
|
||||
for element in data:
|
||||
if isinstance(element, dict):
|
||||
data_to_return.append(
|
||||
{
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in element.items()
|
||||
}
|
||||
)
|
||||
data_to_return.append({k: _move_tensor(v) for k, v in element.items()})
|
||||
else:
|
||||
data_to_return.append(_move_tensor(element))
|
||||
data = data_to_return
|
||||
elif isinstance(data, dict):
|
||||
data = {
|
||||
k: (
|
||||
_move_tensor(v)
|
||||
if k != "inference_params"
|
||||
else v._replace(attention_mask=_move_tensor(v.attention_mask))
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
data = {k: _move_tensor(v) for k, v in data.items()}
|
||||
else:
|
||||
raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
return data
|
||||
|
|
|
@ -138,11 +138,13 @@ def save_optimizer_checkpoint(optim, state_path):
|
|||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||
|
||||
states = optim.state_dict()
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
if gpc.get_global_rank() < optim.zero_world_size:
|
||||
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
if "zero_devide_optim_plan" in states:
|
||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||
|
|
|
@ -22,9 +22,9 @@ class Registry:
|
|||
"""Registers a module represented in `module_class`.
|
||||
|
||||
Args:
|
||||
module_class (class): The module to be registered.
|
||||
module_name (str): The name of module to be registered.
|
||||
Returns:
|
||||
class: The module to be registered, so as to use it normally if via importing.
|
||||
function: The module to be registered, so as to use it normally if via importing.
|
||||
Raises:
|
||||
AssertionError: Raises an AssertionError if the module has already been registered before.
|
||||
"""
|
||||
|
|
9
train.py
9
train.py
|
@ -29,7 +29,7 @@ from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
|||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_master_node,
|
||||
|
@ -93,10 +93,6 @@ def initialize_model():
|
|||
Returns: The neural network model to be trained or evaluated.
|
||||
"""
|
||||
|
||||
assert (
|
||||
not hasattr(gpc.config.parallel, "pipeline") or gpc.config.parallel.pipeline == 1
|
||||
), "Pipeline parallelism is not supported for now."
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
|
@ -314,7 +310,7 @@ def record_current_batch_training_metrics(
|
|||
|
||||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value},"
|
||||
line += f"{key}={value} "
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
logger.info(line)
|
||||
|
@ -466,7 +462,6 @@ def main(args):
|
|||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
timer("fwd-bwd").stop()
|
||||
assert loss is not None
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
|
|
Loading…
Reference in New Issue