mirror of https://github.com/InternLM/InternLM
refactor(scheduler): rewrite pipeline scheduler (#138)
* refactor(scheduler): rewrite pipeline scheduler
* fix(*): fix pipeline scheduler bugs
* fix(*): fix merge bug
* feat(*): update codes with todo tag
* feat(*): add comments
* feat(internlm/core/scheduler): update recv_prev/next logic
* feat(utils/evaluation.py): update sche metric hook for valid

Co-authored-by: huangting.p <huangting@sensetime.com>

parent d67be17f96
commit 0268d8eda1
@@ -117,6 +117,7 @@ model = dict(
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
+    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
 )
 """
 zero1 parallel:
@@ -125,12 +126,14 @@ zero1 parallel:
     2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
     3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
        For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel: pipeline parallel size.
+pipeline parallel (dict):
+    1. size: int, the size of pipeline parallel.
+    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=8,
-    pipeline=2,
+    pipeline=dict(size=1, interleaved_overlap=True),
 )

 cudnn_deterministic = False
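For orientation, a minimal sketch (not part of this diff) of how the two new knobs combine; only the field names come from the snippet above, the concrete sizes are illustrative placeholders:

    # Illustrative values only; this block is not taken from the commit.
    model = dict(
        num_chunks=2,  # > 1 selects the interleaved pipeline scheduler
        # ... remaining model fields as in the original config ...
    )

    parallel = dict(
        zero1=8,
        # pipeline is now a dict instead of a plain int: "size" is the pipeline
        # parallel size, "interleaved_overlap" toggles communication overlap
        # for the interleaved scheduler.
        pipeline=dict(size=4, interleaved_overlap=True),
    )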
@@ -1,10 +1,13 @@
 from .p2p import (
+    AsynCommunicator,
     recv_backward,
     recv_forward,
     send_backward,
+    send_backward_and_recv_next_backward_async,
     send_backward_recv_backward,
     send_backward_recv_forward,
     send_forward,
+    send_forward_and_recv_next_forward_async,
     send_forward_backward_recv_forward_backward,
     send_forward_recv_backward,
     send_forward_recv_forward,
@@ -23,4 +26,7 @@ __all__ = [
     "recv_forward",
     "send_obj_meta",
     "recv_obj_meta",
+    "send_backward_and_recv_next_backward_async",
+    "send_forward_and_recv_next_forward_async",
+    "AsynCommunicator",
 ]
@@ -207,9 +207,6 @@ def recv_forward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
     """
-    if gpc.is_pipeline_first_stage():
-        input_tensor = None
-    else:
     input_tensor, _ = _communicate(
         recv_prev=True,
         recv_prev_shape=input_tensor_shape,
@@ -233,9 +230,6 @@ def recv_backward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
     """
-    if gpc.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
     _, output_tensor_grad = _communicate(
         recv_next=True,
         recv_next_shape=output_grad_shape,
@@ -253,7 +247,6 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) ->
         output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
-    if not gpc.is_pipeline_last_stage():
     _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
@@ -264,14 +257,12 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
         input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
         prev_rank (int, optional): The rank of the recipient of the tensor
     """
-    if not gpc.is_pipeline_first_stage():
-        _communicate(
-            object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
-        )
+    _communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)


 def send_forward_recv_backward(
-    output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
+    output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
@@ -285,24 +276,21 @@ def send_forward_recv_backward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
-    if gpc.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
     _, output_tensor_grad = _communicate(
         object_send_next=output_tensor,
-        recv_next=recv_next,
+        recv_next=output_grad_shape is not None,
         recv_next_shape=output_grad_shape,
         next_rank=next_rank,
         dtype=dtype,
         scatter_gather_tensors=scatter_gather_tensors,
     )

     return output_tensor_grad


 def send_backward_recv_forward(
     input_tensor_grad,
     input_tensor_shape,
-    recv_prev=True,
     prev_rank=None,
     dtype=torch.float,
     scatter_gather_tensors=False,
@@ -319,24 +307,21 @@ def send_backward_recv_forward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
-    if gpc.is_pipeline_first_stage():
-        input_tensor = None
-    else:
     input_tensor, _ = _communicate(
         object_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
+        recv_prev=input_tensor_shape is not None,
         recv_prev_shape=input_tensor_shape,
         prev_rank=prev_rank,
         dtype=dtype,
         scatter_gather_tensors=scatter_gather_tensors,
     )

     return input_tensor


 def send_forward_recv_forward(
     output_tensor,
     input_tensor_shape,
-    recv_prev=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -356,7 +341,7 @@ def send_forward_recv_forward(
     """
     input_tensor, _ = _communicate(
         object_send_next=output_tensor,
-        recv_prev=recv_prev,
+        recv_prev=input_tensor_shape is not None,
         recv_prev_shape=input_tensor_shape,
         prev_rank=prev_rank,
         next_rank=next_rank,
@@ -369,7 +354,6 @@
 def send_backward_recv_backward(
     input_tensor_grad,
     output_grad_shape,
-    recv_next=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -389,7 +373,7 @@ def send_backward_recv_backward(
     """
     _, output_tensor_grad = _communicate(
         object_send_prev=input_tensor_grad,
-        recv_next=recv_next,
+        recv_next=output_grad_shape is not None,
         recv_next_shape=output_grad_shape,
         prev_rank=prev_rank,
         next_rank=next_rank,
@@ -404,8 +388,6 @@ def send_forward_backward_recv_forward_backward(
     input_tensor_grad,
     input_tensor_shape,
     output_grad_shape,
-    recv_prev=True,
-    recv_next=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -430,8 +412,8 @@ def send_forward_backward_recv_forward_backward(
     input_tensor, output_tensor_grad = _communicate(
         object_send_next=output_tensor,
         object_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
+        recv_prev=input_tensor_shape is not None,
+        recv_next=output_grad_shape is not None,
         recv_prev_shape=input_tensor_shape,
         recv_next_shape=output_grad_shape,
         prev_rank=prev_rank,
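The recurring change above replaces the explicit recv_prev/recv_next flags with a convention that a receive is issued only when a shape is supplied. A hedged sketch of the resulting call pattern (not an excerpt from the commit); it assumes an initialized pipeline-parallel process group and an `output_tensor` produced by the current stage, and the dimensions are placeholders:

    import torch

    grad_shape = torch.Size([1, 2048, 4096])  # placeholder micro-batch, seq_len, hidden dims

    # send the activation and receive the matching gradient from the next stage
    output_grad = send_forward_recv_backward(output_tensor, output_grad_shape=grad_shape)

    # send only: passing None suppresses the receive, replacing the old recv_next=False flag
    send_forward_recv_backward(output_tensor, output_grad_shape=None)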
@@ -440,3 +422,159 @@ def send_forward_backward_recv_forward_backward(
         scatter_gather_tensors=scatter_gather_tensors,
     )
     return input_tensor, output_tensor_grad
+
+
+def send_forward_and_recv_next_forward_async(
+    output_tensor,
+    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
+    dtype: torch.dtype = None,
+    scatter_gather_tensors=False,
+):
+    """send forward output to next rank and recv forward input from prev rank"""
+
+    reqs = []
+    tensor_recv_prev = None
+
+    # prepare send operations
+    if output_tensor is not None:
+        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+
+        output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
+
+        if isinstance(output_tensor, torch.Tensor):
+            reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
+        else:
+            for tensor_to_comm in output_tensor:
+                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
+
+    # prepare receive operations
+    if recv_prev_shape is not None:
+        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+        # create receive buffer
+        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
+            recv_prev_shape, dtype, scatter_gather_tensors
+        )
+        # generate async receive operations
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
+        else:
+            for tensor_to_comm in tensor_recv_prev:
+                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    # return and do other things
+    yield
+
+    # check communication completed
+    for req in reqs:
+        req.wait()
+    # To protect against race condition when using batch_isend_irecv()
+    torch.cuda.synchronize()
+
+    # Process received data
+    if recv_prev_shape is not None and recv_prev_split:
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        else:
+            for index in range(len(tensor_recv_prev)):
+                tensor_recv_prev[index] = (
+                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
+                )
+
+    yield tensor_recv_prev
+
+
+def send_backward_and_recv_next_backward_async(
+    input_tensor,
+    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
+    dtype: torch.dtype = None,
+    scatter_gather_tensors=False,
+):
+    reqs = []
+    tensor_recv_next = None
+
+    # prepare send operations
+    if input_tensor is not None:
+        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+
+        input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
+
+        if isinstance(input_tensor, torch.Tensor):
+            reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
+        else:
+            for tensor_to_comm in input_tensor:
+                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
+
+    # prepare receive operations
+    if recv_next_shape is not None:
+        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+        # create receive buffer
+        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
+            recv_next_shape, dtype, scatter_gather_tensors
+        )
+        # generate async receive operations
+        if isinstance(tensor_recv_next, torch.Tensor):
+            reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
+        else:
+            for tensor_to_comm in tensor_recv_next:
+                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    # return and do other things
+    yield
+
+    # check communication completed
+    for req in reqs:
+        req.wait()
+    # To protect against race condition when using batch_isend_irecv()
+    torch.cuda.synchronize()
+
+    # Process received data
+    if recv_next_shape is not None and recv_next_split:
+        if isinstance(tensor_recv_next, torch.Tensor):
+            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        else:
+            for index in range(len(tensor_recv_next)):
+                tensor_recv_next[index] = (
+                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
+                )
+
+    yield tensor_recv_next
+
+
+class AsynCommunicator:
+    def __init__(
+        self,
+        tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
+        recv_shape: Union[torch.Size, List[torch.Size]],
+        dtype: torch.dtype = None,
+        scatter_gather_tensors=False,
+        forward: bool = True,
+    ) -> None:
+        self._need_receive = recv_shape is not None
+
+        if forward:
+            self._coroutine = send_forward_and_recv_next_forward_async(
+                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
+            )
+        else:
+            self._coroutine = send_backward_and_recv_next_backward_async(
+                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
+            )
+
+    @property
+    def need_receive(self) -> bool:
+        return self._need_receive
+
+    def start(self) -> None:
+        next(self._coroutine)
+
+    def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        received = next(self._coroutine)
+        self._coroutine.close()
+
+        return received
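A short usage sketch of the new AsynCommunicator (assumed tensors and shapes, not code from the commit): start() issues the batched isend/irecv and returns, the caller overlaps independent work, and wait_and_receive() blocks on the requests and returns the received tensor(s).

    # Hedged sketch; output_tensor and recv_shape are assumed to be defined by the caller.
    comm = AsynCommunicator(
        tensor_to_send=output_tensor,   # activation to forward to the next stage
        recv_shape=recv_shape,          # shape expected from the previous stage, or None
        dtype=torch.half,
        scatter_gather_tensors=False,
        forward=True,                   # forward direction; False drives the backward pair
    )
    comm.start()                        # launches dist.batch_isend_irecv and returns immediately

    # ... overlap independent computation here, e.g. the current chunk's forward ...

    if comm.need_receive:
        input_tensor = comm.wait_and_receive()  # waits on the P2P requests, returns the result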
@@ -1,5 +1,6 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

+from functools import wraps
 from typing import List, Tuple, Union

 import torch
@@ -19,7 +20,7 @@ def send_meta_helper(obj, next_rank, tensor_kwargs):
     dist.send(send_shape, next_rank)


-def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
+def send_obj_meta(obj, next_rank=None):
     """Sends obj meta information before sending a specific obj.
     Since the recipient must know the shape of the obj in p2p communications,
     meta information of the obj should be sent before communications. This function
@@ -33,7 +34,6 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
     Returns:
         bool: False
     """
-    if need_meta:
     if next_rank is None:
         next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
@@ -48,8 +48,6 @@ def send_obj_meta(obj, next_rank=None):
     for tensor_to_send in obj:
         send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)

-    return False
-

 def recv_meta_helper(prev_rank, tensor_kwargs):
     recv_ndims = torch.empty((), **tensor_kwargs)
@@ -59,7 +57,7 @@ def recv_meta_helper(prev_rank, tensor_kwargs):
     return recv_shape


-def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
+def recv_obj_meta(prev_rank=None) -> torch.Size:
     """Receives obj meta information before receiving a specific obj.
     Since the recipient must know the shape of the obj in p2p communications,
     meta information of the obj should be received before communications. This function
@@ -72,7 +70,6 @@ def recv_obj_meta(prev_rank=None) -> torch.Size:
     Returns:
         Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
     """
-    if obj_shape is None:
     if prev_rank is None:
         prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
@@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
             input_ = input_.float()
         return input_

+    def convert_to_fp32(self, out):
+        """Converts the output to fp32"""
+        if isinstance(out, Tensor):
+            out = self._convert_to_fp32(out)
+        elif isinstance(out, (tuple, list)):
+            out = [self._convert_to_fp32(val) for val in out]
+        elif isinstance(out, dict):
+            out = {key: self._convert_to_fp32(val) for key, val in out.items()}
+
+        return out
+
     def _reduce_module_buffer(self):
         """
         All-reduces the buffers (e.g., running stats of batch normalization) across
@@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
         out = self.model(*args, **kwargs)

         if self._output_to_fp32:
-            if isinstance(out, Tensor):
-                out = self._convert_to_fp32(out)
-            elif isinstance(out, (tuple, list)):
-                out = [self._convert_to_fp32(val) for val in out]
-            elif isinstance(out, dict):
-                out = {key: self._convert_to_fp32(val) for key, val in out.items()}
+            out = self.convert_to_fp32(out)
         return out
@@ -1,5 +1,12 @@
-from .base_scheduler import BaseScheduler
+from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
 from .no_pipeline_scheduler import NonPipelineScheduler
 from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler

-__all__ = ["BaseScheduler", "NonPipelineScheduler", "InterleavedPipelineScheduler", "PipelineScheduler"]
+__all__ = [
+    "BaseScheduler",
+    "NonPipelineScheduler",
+    "InterleavedPipelineScheduler",
+    "PipelineScheduler",
+    "SchedulerHook",
+    "SchedulerMetricHook",
+]
@@ -4,11 +4,12 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Optional

 import torch

 from internlm.core.engine import Engine
+from internlm.utils.megatron_timers import megatron_timer as timer


 class BaseScheduler(ABC):
@@ -112,3 +113,85 @@ class BaseScheduler(ABC):
                 '(which is auto-converted to tuple), list, tuple, or dict, ' \
                 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
             )
+
+
+class SchedulerHook(ABC):
+    """
+    Scheduler Hook.
+    """
+
+    @abstractmethod
+    def before_forward(self, scheduler, inputs) -> None:
+        """Actions before forward"""
+
+    @abstractmethod
+    def after_forward(self, scheduler, outputs) -> None:
+        """Actions after forward"""
+
+    @abstractmethod
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        """Actions before criterion"""
+
+    @abstractmethod
+    def after_criterion(self, scheduler, loss) -> None:
+        """Actions after criterion"""
+
+    @abstractmethod
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        """Actions before backward"""
+
+    @abstractmethod
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        """Actions after backward"""
+
+    @abstractmethod
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        """A post helper function"""
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+    """
+
+    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
+        self._post_func = metric
+        self._skip = skip
+
+        if skip:
+            # init timer only.
+            timer("fwd")
+            timer("bwd")
+            timer("cal_loss")
+            timer("post_fn")
+
+    def before_forward(self, scheduler, inputs) -> None:
+        if not self._skip:
+            timer("fwd").start()
+
+    def after_forward(self, scheduler, outputs) -> None:
+        if not self._skip:
+            timer("fwd").stop()
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("cal_loss").start()
+
+    def after_criterion(self, scheduler, loss) -> None:
+        if not self._skip:
+            timer("cal_loss").stop()
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").start()
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").stop()
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("post_fn").start()
+            if self._post_func is not None:
+                self._post_func(outputs, label)
+            timer("post_fn").stop()
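Besides the built-in SchedulerMetricHook, any object implementing the SchedulerHook interface above can be passed to a scheduler. A minimal hypothetical example (the LossLoggingHook name and behaviour are illustrative, not part of the commit):

    from internlm.core.scheduler import SchedulerHook

    class LossLoggingHook(SchedulerHook):
        """Hypothetical hook that records every micro-batch loss; illustrative only."""

        def __init__(self) -> None:
            self.losses = []

        def before_forward(self, scheduler, inputs) -> None:
            pass

        def after_forward(self, scheduler, outputs) -> None:
            pass

        def before_criterion(self, scheduler, outputs, label) -> None:
            pass

        def after_criterion(self, scheduler, loss) -> None:
            # loss is the per-micro-batch loss computed by the engine criterion
            self.losses.append(loss.item())

        def before_backward(self, scheduler, outputs, outputs_grad) -> None:
            pass

        def after_backward(self, scheduler, inputs_grad) -> None:
            pass

        def post_helper_func(self, scheduler, outputs, label) -> None:
            pass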
@@ -3,14 +3,14 @@

 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, List, Optional

 import torch

 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context

-from .base_scheduler import BaseScheduler
+from .base_scheduler import BaseScheduler, SchedulerHook


 class NonPipelineScheduler(BaseScheduler):
@@ -34,10 +34,17 @@ class NonPipelineScheduler(BaseScheduler):
         return data, label
     """

-    def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
+    def __init__(
+        self,
+        data_process_func: Callable = None,
+        gradient_accumulation_size: int = 1,
+        scheduler_hooks: Optional[List[SchedulerHook]] = None,
+    ):
         self._grad_accum_size = gradient_accumulation_size
         self._grad_accum_offset = 0

+        self._hooks = scheduler_hooks
+
         super().__init__(data_process_func)

     def pre_processing(self, engine: Engine):
@@ -48,6 +55,10 @@ class NonPipelineScheduler(BaseScheduler):
         """
         pass

+    def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
+        for hook in self._hooks:
+            getattr(hook, func_name)(self, *args, **kwargs)
+
     def _load_accum_batch(self, data: Any, label: Any):
         """Loads a batch of data and label for gradient accumulation.
@@ -77,7 +88,6 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         scale_loss: int = 1,
-        post_fn: Callable = None,
     ):
         """Trains one batch of data.
@@ -89,23 +99,27 @@ class NonPipelineScheduler(BaseScheduler):
                 be executed.
             return_loss (bool, optional): Loss will be returned if True.
             scale_loss (int, optional): The scale factor for the loss.
-            post_fn (Callable, optional): Call back function after executing data forward output.
         """

         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
+            self._call_hooks("before_forward", data)
             output = self._call_engine(engine, data)
+            self._call_hooks("after_forward", output)

-            if post_fn is not None:
-                post_fn(output, label)
+            self._call_hooks("post_helper_func", output, label)

             if return_loss:
+                self._call_hooks("before_criterion", output, label)
                 loss = self._call_engine_criterion(engine, output, label)
+                self._call_hooks("after_criterion", loss)
                 loss /= scale_loss

         # backward
         if not forward_only:
+            self._call_hooks("before_backward", None, None)
             engine.backward(loss)
+            self._call_hooks("after_backward", None)

         if not return_loss:
             loss = None
@@ -119,7 +133,6 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         return_output_label: bool = True,
-        post_fn: Callable = None,
     ):
         """The process function that loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
@@ -131,7 +144,6 @@ class NonPipelineScheduler(BaseScheduler):
                 If True, the model is run for the forward pass, else back propagation will be executed.
             return_loss (bool, optional): Loss will be returned if True.
             return_output_label (bool, optional): Output and label will be returned if True.
-            post_fn (Callable, optional): Call back function after executing data forward output.

         Returns:
             Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
@@ -165,7 +177,7 @@ class NonPipelineScheduler(BaseScheduler):
             _data, _label = self._load_accum_batch(data, label)

             _output, _loss = self._train_one_batch(
-                _data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
+                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )

             if return_loss:
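With post_fn removed, per-step metrics now travel through scheduler hooks. A hedged construction sketch (metric_fn and the accumulation size are placeholders, not values from the diff):

    # Hedged sketch: wiring a metric into the scheduler via hooks instead of post_fn.
    def metric_fn(outputs, label):
        ...  # placeholder: e.g. update an AccPerplex-style accumulator

    scheduler = NonPipelineScheduler(
        data_process_func=None,
        gradient_accumulation_size=4,  # illustrative value
        scheduler_hooks=[SchedulerMetricHook(metric=metric_fn)],
    )
    # forward_backward_step then calls the hook around forward/criterion/backward and
    # invokes metric_fn(outputs, label) through post_helper_func.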
(File diff suppressed because it is too large.)
@@ -3,7 +3,7 @@

 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize

-from typing import Callable, Iterable, Optional, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple

 from torch import nn
 from torch.nn.modules.loss import _Loss
@@ -15,12 +15,13 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
-from internlm.core.scheduler.no_pipeline_scheduler import NonPipelineScheduler
-from internlm.core.scheduler.pipeline_scheduler import (
+from internlm.core.scheduler import (
     InterleavedPipelineScheduler,
+    NonPipelineScheduler,
     PipelineScheduler,
-    get_tensor_shape,
+    SchedulerHook,
 )
+from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
 from internlm.core.trainer import Trainer
 from internlm.data.utils import unpack_data
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -36,6 +37,7 @@ def initialize_trainer(
     test_dataloader: Optional[Iterable] = None,
     lr_scheduler: Optional[_LRScheduler] = None,
     beta2_scheduler: Optional[Beta2Scheduler] = None,
+    scheduler_hooks: Optional[List[SchedulerHook]] = None,
 ) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
@@ -92,12 +94,16 @@ def initialize_trainer(
         if use_interleaved:
             if isinstance(model, nn.Sequential):
                 model = nn.ModuleList([model])

+            communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
             scheduler = InterleavedPipelineScheduler(
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
-                num_model_chunks=gpc.config.model.num_chunks,
+                num_chunks=gpc.config.model.num_chunks,
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
+                scheduler_hooks=scheduler_hooks,
+                communication_overlap=communication_overlap,
             )
         else:
             scheduler = PipelineScheduler(
@@ -106,10 +112,13 @@ def initialize_trainer(
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
+                scheduler_hooks=scheduler_hooks,
             )
     else:
         scheduler = NonPipelineScheduler(
-            data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
+            data_process_func=data_fn,
+            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
+            scheduler_hooks=scheduler_hooks,
         )

     # initialize engine for trainer
@@ -7,40 +7,47 @@ from tqdm import tqdm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.metrics import AccPerplex
+from internlm.core.scheduler import SchedulerMetricHook


 @contextmanager
-def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
+def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
     if not gpc.is_using_pp():
         prev_data_process_func = trainer.schedule.data_process_func
         prev_grad_accum_size = trainer.schedule._grad_accum_size
         prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
+        prev_metric_hooks = trainer.schedule._hooks
         try:
             trainer.schedule.data_process_func = None
             trainer.schedule._grad_accum_size = grad_accum_size
             trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
+            trainer.schedule._hooks = metric_hook_list
             yield
         finally:
             trainer.schedule.data_process_func = prev_data_process_func
             trainer.schedule._grad_accum_size = prev_grad_accum_size
             trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
+            trainer.schedule._hooks = prev_metric_hooks


 @contextmanager
-def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
+def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
     if gpc.is_using_pp():
         pre_data_process_func = trainer.schedule.data_process_func
         prev_num_microbatches = trainer.schedule.num_microbatches
         prev_tensor_shape = trainer.schedule.tensor_shape
+        prev_metric_hooks = trainer.schedule._hooks
         try:
             trainer.schedule.data_process_func = None
             trainer.schedule.num_microbatches = num_microbatches
             trainer.schedule.tensor_shape = tensor_shape
+            trainer.schedule._hooks = metric_hook_list
             yield
         finally:
             trainer.schedule.data_process_func = pre_data_process_func
             trainer.schedule.num_microbatches = prev_num_microbatches
             trainer.schedule.tensor_shape = prev_tensor_shape
+            trainer.schedule._hooks = prev_metric_hooks


 def evaluate_on_val_dls(
@@ -49,7 +56,6 @@ def evaluate_on_val_dls(
     writer,
     logger,
     step_count,
-    tokenizer=None,
     update_panel: bool = False,
 ):
     torch.cuda.empty_cache()
@@ -66,8 +72,9 @@ def evaluate_on_val_dls(
             device=torch.cuda.current_device(),
             tp_pg=gpc.get_group(ParallelMode.TENSOR),
             dp_pg=gpc.get_group(ParallelMode.DATA),
-            tokenizer=tokenizer,
         )
+        val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)

         val_loss = 0
         val_idx = -1
         for val_idx, batch in tqdm(
@@ -88,10 +95,13 @@ def evaluate_on_val_dls(
                 )

                 with switch_evaluation_pipeline_scheduler(
-                    trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
+                    trainer=trainer,
+                    num_microbatches=num_microbatches,
+                    tensor_shape=tensor_shape,
+                    metric_hook_list=[val_sche_metric_hook],
                 ):
                     _, _, loss = trainer.execute_schedule(
-                        batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
+                        batch, forward_only=True, return_loss=True, return_output_label=False
                     )
             else:
                 total_val_bsz = len(batch[1])
@@ -100,38 +110,42 @@ def evaluate_on_val_dls(
                 grad_accum_batch_size = data_cfg.micro_bsz

                 with switch_evaluation_no_pipeline_scheduler(
-                    trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
+                    trainer=trainer,
+                    grad_accum_size=grad_accum_size,
+                    grad_accum_batch_size=grad_accum_batch_size,
+                    metric_hook_list=[val_sche_metric_hook],
                 ):
                     _, _, loss = trainer.execute_schedule(
-                        batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
+                        batch, forward_only=True, return_loss=True, return_output_label=False
                     )
             if verbose:
                 val_loss += loss.item()

         assert val_idx != -1
         dist.barrier()
-        val_res = val_metric.get_metric()

+        val_res = val_metric.get_metric()
         if verbose and len(val_dl) != 0:
             val_loss = val_loss / (val_idx + 1 + 1e-6)
             infos = {
+                "step": step_count,
                 f"val/{val_name}_loss": val_loss,
                 f"val/{val_name}_acc": val_res["acc"],
                 f"val/{val_name}_plex": val_res["perplexity"],
             }
-            val_metric = {
-                "step": step_count,
-                "val_loss": val_loss,
-                "val_acc": val_res["acc"],
-                "val_perplexity": val_res["perplexity"],
-            }
             for key, value in infos.items():
                 writer.add_scalar(key=key, value=value, step=step_count)
-            infos["step"] = step_count
+
             if update_panel:
                 logger.info(
                     f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
-                    extra=val_metric,
+                    extra={
+                        "step": step_count,
+                        "val_loss": val_loss,
+                        "val_acc": val_res["acc"],
+                        "val_perplexity": val_res["perplexity"],
+                    },
                 )
             else:
                 logger.info(
train.py
@@ -16,6 +16,7 @@ import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
 from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
 from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
@@ -109,6 +110,19 @@ def initialize_model():
     """

     model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
+    if isinstance(model, nn.ModuleList):
+        model = nn.ModuleList(
+            [
+                NaiveAMPModel(
+                    model=_m,
+                    output_to_fp32=False,  # manually controlled by interleaved pipeline scheduler
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    sync_buffer=False,
+                )
+                for _m in model
+            ]
+        )
+    else:
         model = NaiveAMPModel(
             model=model,
             output_to_fp32=is_no_pp_or_last_stage(),
@@ -500,19 +514,6 @@ def main(args):
     if load_optimizer:
         load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)

-    # initialize trainer
-    trainer, train_dl, _, _ = internlm.initialize_trainer(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        train_dataloader=train_dl,
-        lr_scheduler=lr_scheduler,
-        beta2_scheduler=beta2_scheduler,
-    )
-
-    # initialize the batch skipper
-    batch_skipper = BatchSkipper(skip_batches)
-
     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
         device=torch.cuda.current_device(),
@@ -521,6 +522,27 @@ def main(args):
         dataset_types=dataset_types,
     )

+    # initialize trainer
+    scheduler_hooks = [
+        SchedulerMetricHook(
+            metric=metric,
+            skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
+        ),
+    ]
+
+    trainer, train_dl, _, _ = internlm.initialize_trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dl,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        scheduler_hooks=scheduler_hooks,
+    )
+
+    # initialize the batch skipper
+    batch_skipper = BatchSkipper(skip_batches)
+
     trainer.train()

     # transfer the train data loader into train data iterator
@@ -558,9 +580,7 @@ def main(args):

         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(
-            batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
-        )
+        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)