refactor(scheduler): rewrite pipeline scheduler (#138)

* refactor(scheduler): rewrite pipeline scheduler

* fix(*): fix pipeline scheduler bugs

* fix(*): fix merge bug

* feat(*): update code with TODO tags

* feat(*): add comments

* feat(internlm/core/scheduler): update recv_prev/next logic

* feat(utils/evaluation.py): update scheduler metric hook for validation

---------

Co-authored-by: huangting.p <huangting@sensetime.com>
cx 2023-08-03 11:48:12 +08:00 committed by GitHub
parent d67be17f96
commit 0268d8eda1
12 changed files with 1412 additions and 712 deletions


@@ -117,6 +117,7 @@ model = dict(
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
+    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
 )
 """
 zero1 parallel:
@@ -125,12 +126,14 @@ zero1 parallel:
 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
 For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel: pipeline parallel size.
+pipeline parallel (dict):
+    1. size: int, the size of pipeline parallel.
+    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=8,
-    pipeline=2,
+    pipeline=dict(size=1, interleaved_overlap=True),
 )
 cudnn_deterministic = False
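For context, the two new knobs work together: num_chunks selects the interleaved scheduler, and pipeline.interleaved_overlap turns on communication overlap. A hypothetical configuration enabling both (values are illustrative, not from this commit):

model = dict(
    # ... other model fields as above ...
    num_chunks=2,  # > 1 selects the interleaved pipeline scheduler
)

parallel = dict(
    zero1=8,
    pipeline=dict(size=4, interleaved_overlap=True),  # 4 pipeline stages with overlap enabled
)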


@ -1,10 +1,13 @@
from .p2p import ( from .p2p import (
AsynCommunicator,
recv_backward, recv_backward,
recv_forward, recv_forward,
send_backward, send_backward,
send_backward_and_recv_next_backward_async,
send_backward_recv_backward, send_backward_recv_backward,
send_backward_recv_forward, send_backward_recv_forward,
send_forward, send_forward,
send_forward_and_recv_next_forward_async,
send_forward_backward_recv_forward_backward, send_forward_backward_recv_forward_backward,
send_forward_recv_backward, send_forward_recv_backward,
send_forward_recv_forward, send_forward_recv_forward,
@ -23,4 +26,7 @@ __all__ = [
"recv_forward", "recv_forward",
"send_obj_meta", "send_obj_meta",
"recv_obj_meta", "recv_obj_meta",
"send_backward_and_recv_next_backward_async",
"send_forward_and_recv_next_forward_async",
"AsynCommunicator",
] ]
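With the expanded exports, downstream code can import the async helpers directly from the communication package; a quick sketch:

from internlm.core.communication import (
    AsynCommunicator,
    send_backward_and_recv_next_backward_async,
    send_forward_and_recv_next_forward_async,
)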


@@ -207,9 +207,6 @@ def recv_forward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
     """
-    if gpc.is_pipeline_first_stage():
-        input_tensor = None
-    else:
     input_tensor, _ = _communicate(
         recv_prev=True,
         recv_prev_shape=input_tensor_shape,
@@ -233,9 +230,6 @@ def recv_backward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
     """
-    if gpc.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
     _, output_tensor_grad = _communicate(
         recv_next=True,
         recv_next_shape=output_grad_shape,
@@ -253,7 +247,6 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
         output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
-    if not gpc.is_pipeline_last_stage():
     _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
@@ -264,14 +257,12 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
         input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
         prev_rank (int, optional): The rank of the recipient of the tensor
     """
-    if not gpc.is_pipeline_first_stage():
-        _communicate(
-            object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
-        )
+    _communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)


 def send_forward_recv_backward(
-    output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
+    output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
 ) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
@@ -285,24 +276,21 @@ def send_forward_recv_backward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
-    if gpc.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
     _, output_tensor_grad = _communicate(
         object_send_next=output_tensor,
-        recv_next=recv_next,
+        recv_next=output_grad_shape is not None,
         recv_next_shape=output_grad_shape,
         next_rank=next_rank,
         dtype=dtype,
         scatter_gather_tensors=scatter_gather_tensors,
     )
     return output_tensor_grad


 def send_backward_recv_forward(
     input_tensor_grad,
     input_tensor_shape,
-    recv_prev=True,
     prev_rank=None,
     dtype=torch.float,
     scatter_gather_tensors=False,
@@ -319,24 +307,21 @@ def send_backward_recv_forward(
     Returns:
         Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
-    if gpc.is_pipeline_first_stage():
-        input_tensor = None
-    else:
     input_tensor, _ = _communicate(
         object_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
+        recv_prev=input_tensor_shape is not None,
         recv_prev_shape=input_tensor_shape,
         prev_rank=prev_rank,
         dtype=dtype,
         scatter_gather_tensors=scatter_gather_tensors,
     )
     return input_tensor


 def send_forward_recv_forward(
     output_tensor,
     input_tensor_shape,
-    recv_prev=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -356,7 +341,7 @@ def send_forward_recv_forward(
     """
     input_tensor, _ = _communicate(
         object_send_next=output_tensor,
-        recv_prev=recv_prev,
+        recv_prev=input_tensor_shape is not None,
         recv_prev_shape=input_tensor_shape,
         prev_rank=prev_rank,
         next_rank=next_rank,
@@ -369,7 +354,6 @@ def send_forward_recv_forward(
 def send_backward_recv_backward(
     input_tensor_grad,
     output_grad_shape,
-    recv_next=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -389,7 +373,7 @@ def send_backward_recv_backward(
     """
     _, output_tensor_grad = _communicate(
         object_send_prev=input_tensor_grad,
-        recv_next=recv_next,
+        recv_next=output_grad_shape is not None,
         recv_next_shape=output_grad_shape,
         prev_rank=prev_rank,
         next_rank=next_rank,
@@ -404,8 +388,6 @@ def send_forward_backward_recv_forward_backward(
     input_tensor_grad,
     input_tensor_shape,
     output_grad_shape,
-    recv_prev=True,
-    recv_next=True,
     prev_rank=None,
     next_rank=None,
     dtype=torch.float,
@@ -430,8 +412,8 @@ def send_forward_backward_recv_forward_backward(
     input_tensor, output_tensor_grad = _communicate(
         object_send_next=output_tensor,
         object_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
+        recv_prev=input_tensor_shape is not None,
+        recv_next=output_grad_shape is not None,
         recv_prev_shape=input_tensor_shape,
         recv_next_shape=output_grad_shape,
         prev_rank=prev_rank,
@@ -440,3 +422,159 @@ def send_forward_backward_recv_forward_backward(
         scatter_gather_tensors=scatter_gather_tensors,
     )
     return input_tensor, output_tensor_grad
+
+
+def send_forward_and_recv_next_forward_async(
+    output_tensor,
+    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
+    dtype: torch.dtype = None,
+    scatter_gather_tensors=False,
+):
+    """send forward output to next rank and recv forward input from prev rank"""
+    reqs = []
+    tensor_recv_prev = None
+
+    # prepare send operations
+    if output_tensor is not None:
+        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+
+        output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
+
+        if isinstance(output_tensor, torch.Tensor):
+            reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
+        else:
+            for tensor_to_comm in output_tensor:
+                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
+
+    # prepare receive operations
+    if recv_prev_shape is not None:
+        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+        # create receive buffer
+        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
+            recv_prev_shape, dtype, scatter_gather_tensors
+        )
+        # generate async receive operations
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
+        else:
+            for tensor_to_comm in tensor_recv_prev:
+                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    # return and do other things
+    yield
+
+    # check communication completed
+    for req in reqs:
+        req.wait()
+    # To protect against race condition when using batch_isend_irecv()
+    torch.cuda.synchronize()
+
+    # process received data
+    if recv_prev_shape is not None and recv_prev_split:
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        else:
+            for index in range(len(tensor_recv_prev)):
+                tensor_recv_prev[index] = (
+                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
+                )
+
+    yield tensor_recv_prev
+
+
+def send_backward_and_recv_next_backward_async(
+    input_tensor,
+    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
+    dtype: torch.dtype = None,
+    scatter_gather_tensors=False,
+):
+    """send backward grad to prev rank and recv backward grad from next rank"""
+    reqs = []
+    tensor_recv_next = None
+
+    # prepare send operations
+    if input_tensor is not None:
+        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
+
+        input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
+
+        if isinstance(input_tensor, torch.Tensor):
+            reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
+        else:
+            for tensor_to_comm in input_tensor:
+                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
+
+    # prepare receive operations
+    if recv_next_shape is not None:
+        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
+        # create receive buffer
+        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
+            recv_next_shape, dtype, scatter_gather_tensors
+        )
+        # generate async receive operations
+        if isinstance(tensor_recv_next, torch.Tensor):
+            reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
+        else:
+            for tensor_to_comm in tensor_recv_next:
+                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    # return and do other things
+    yield
+
+    # check communication completed
+    for req in reqs:
+        req.wait()
+    # To protect against race condition when using batch_isend_irecv()
+    torch.cuda.synchronize()
+
+    # process received data
+    if recv_next_shape is not None and recv_next_split:
+        if isinstance(tensor_recv_next, torch.Tensor):
+            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        else:
+            for index in range(len(tensor_recv_next)):
+                tensor_recv_next[index] = (
+                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
+                )
+
+    yield tensor_recv_next
+
+
+class AsynCommunicator:
+    """Wraps the async send/recv generators for overlapped pipeline communication."""
+
+    def __init__(
+        self,
+        tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
+        recv_shape: Union[torch.Size, List[torch.Size]],
+        dtype: torch.dtype = None,
+        scatter_gather_tensors=False,
+        forward: bool = True,
+    ) -> None:
+        self._need_receive = recv_shape is not None
+
+        if forward:
+            self._coroutine = send_forward_and_recv_next_forward_async(
+                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
+            )
+        else:
+            self._coroutine = send_backward_and_recv_next_backward_async(
+                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
+            )
+
+    @property
+    def need_receive(self) -> bool:
+        return self._need_receive
+
+    def start(self) -> None:
+        next(self._coroutine)
+
+    def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        received = next(self._coroutine)
+        self._coroutine.close()
+        return received
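Two things worth noting in this file. First, the synchronous helpers now derive recv_prev/recv_next from whether a shape argument is supplied, so passing None for the shape skips that receive. Second, the generator-based helpers are wrapped by AsynCommunicator; a minimal usage sketch (not part of the diff; run_forward_chunk is a hypothetical stand-in for the compute that overlaps with communication):

comm = AsynCommunicator(
    tensor_to_send=output_tensor,   # activation destined for the next stage
    recv_shape=input_tensor_shape,  # pass None to disable the receive half
    dtype=torch.half,
    forward=True,
)
comm.start()  # issues dist.batch_isend_irecv and returns immediately

output = run_forward_chunk()  # hypothetical compute overlapped with the P2P traffic

if comm.need_receive:
    next_input = comm.wait_and_receive()  # waits on the requests and returns the tensor(s)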


@@ -1,5 +1,6 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

+from functools import wraps
 from typing import List, Tuple, Union

 import torch
@@ -19,7 +20,7 @@ def send_meta_helper(obj, next_rank, tensor_kwargs):
     dist.send(send_shape, next_rank)


-def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
+def send_obj_meta(obj, next_rank=None):
     """Sends obj meta information before sending a specific obj.
     Since the recipient must know the shape of the obj in p2p communications,
     meta information of the obj should be sent before communications. This function
@@ -33,7 +34,6 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
     Returns:
         bool: False
     """
-    if need_meta:
     if next_rank is None:
         next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
@@ -48,8 +48,6 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
         for tensor_to_send in obj:
             send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)

-    return False
-

 def recv_meta_helper(prev_rank, tensor_kwargs):
     recv_ndims = torch.empty((), **tensor_kwargs)
@@ -59,7 +57,7 @@ def recv_meta_helper(prev_rank, tensor_kwargs):
     return recv_shape


-def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
+def recv_obj_meta(prev_rank=None) -> torch.Size:
     """Receives obj meta information before receiving a specific obj.
     Since the recipient must know the shape of the obj in p2p communications,
     meta information of the obj should be received before communications. This function
@@ -72,7 +70,6 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
     Returns:
         Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
     """
-    if obj_shape is None:
     if prev_rank is None:
         prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)


@@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
             input_ = input_.float()
         return input_

+    def convert_to_fp32(self, out):
+        """Converts the output to fp32"""
+        if isinstance(out, Tensor):
+            out = self._convert_to_fp32(out)
+        elif isinstance(out, (tuple, list)):
+            out = [self._convert_to_fp32(val) for val in out]
+        elif isinstance(out, dict):
+            out = {key: self._convert_to_fp32(val) for key, val in out.items()}
+
+        return out
+
     def _reduce_module_buffer(self):
         """
         All-reduces the buffers (e.g., running stats of batch normalization) across
@@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
             out = self.model(*args, **kwargs)

         if self._output_to_fp32:
-            if isinstance(out, Tensor):
-                out = self._convert_to_fp32(out)
-            elif isinstance(out, (tuple, list)):
-                out = [self._convert_to_fp32(val) for val in out]
-            elif isinstance(out, dict):
-                out = {key: self._convert_to_fp32(val) for key, val in out.items()}
+            out = self.convert_to_fp32(out)
         return out
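Promoting the conversion into a public convert_to_fp32 method pairs with the train.py change below, where interleaved pipeline chunks are wrapped with output_to_fp32=False; the scheduler is then expected to convert last-stage outputs itself. A minimal sketch under that assumption (amp_model is a NaiveAMPModel, is_last_stage is a hypothetical flag supplied by the caller):

out = amp_model(batch)  # stays in fp16/bf16 when output_to_fp32=False
if is_last_stage:
    out = amp_model.convert_to_fp32(out)  # handles Tensor, tuple/list, and dict outputs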


@ -1,5 +1,12 @@
from .base_scheduler import BaseScheduler from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
from .no_pipeline_scheduler import NonPipelineScheduler from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
__all__ = ["BaseScheduler", "NonPipelineScheduler", "InterleavedPipelineScheduler", "PipelineScheduler"] __all__ = [
"BaseScheduler",
"NonPipelineScheduler",
"InterleavedPipelineScheduler",
"PipelineScheduler",
"SchedulerHook",
"SchedulerMetricHook",
]


@@ -4,11 +4,12 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Optional

 import torch

 from internlm.core.engine import Engine
+from internlm.utils.megatron_timers import megatron_timer as timer


 class BaseScheduler(ABC):
@@ -112,3 +113,85 @@ class BaseScheduler(ABC):
                 '(which is auto-converted to tuple), list, tuple, or dict, ' \
                 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
             )
+
+
+class SchedulerHook(ABC):
+    """
+    Scheduler Hook.
+    """
+
+    @abstractmethod
+    def before_forward(self, scheduler, inputs) -> None:
+        """Actions before forward"""
+
+    @abstractmethod
+    def after_forward(self, scheduler, outputs) -> None:
+        """Actions after forward"""
+
+    @abstractmethod
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        """Actions before criterion"""
+
+    @abstractmethod
+    def after_criterion(self, scheduler, loss) -> None:
+        """Actions after criterion"""
+
+    @abstractmethod
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        """Actions before backward"""
+
+    @abstractmethod
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        """Actions after backward"""
+
+    @abstractmethod
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        """A post helper function"""
+
+
+class SchedulerMetricHook(SchedulerHook):
+    """
+    Scheduler Metric Hook.
+    """
+
+    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
+        self._post_func = metric
+        self._skip = skip
+
+        if skip:
+            # init timer only.
+            timer("fwd")
+            timer("bwd")
+            timer("cal_loss")
+            timer("post_fn")
+
+    def before_forward(self, scheduler, inputs) -> None:
+        if not self._skip:
+            timer("fwd").start()
+
+    def after_forward(self, scheduler, outputs) -> None:
+        if not self._skip:
+            timer("fwd").stop()
+
+    def before_criterion(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("cal_loss").start()
+
+    def after_criterion(self, scheduler, loss) -> None:
+        if not self._skip:
+            timer("cal_loss").stop()
+
+    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").start()
+
+    def after_backward(self, scheduler, inputs_grad) -> None:
+        if not self._skip:
+            timer("bwd").stop()
+
+    def post_helper_func(self, scheduler, outputs, label) -> None:
+        if not self._skip:
+            timer("post_fn").start()
+            if self._post_func is not None:
+                self._post_func(outputs, label)
+            timer("post_fn").stop()


@@ -3,14 +3,14 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, List, Optional

 import torch

 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context

-from .base_scheduler import BaseScheduler
+from .base_scheduler import BaseScheduler, SchedulerHook


 class NonPipelineScheduler(BaseScheduler):
@@ -34,10 +34,17 @@ class NonPipelineScheduler(BaseScheduler):
             return data, label
     """

-    def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
+    def __init__(
+        self,
+        data_process_func: Callable = None,
+        gradient_accumulation_size: int = 1,
+        scheduler_hooks: Optional[List[SchedulerHook]] = None,
+    ):
         self._grad_accum_size = gradient_accumulation_size
         self._grad_accum_offset = 0
+        self._hooks = scheduler_hooks

         super().__init__(data_process_func)

     def pre_processing(self, engine: Engine):
@@ -48,6 +55,10 @@ class NonPipelineScheduler(BaseScheduler):
         """
         pass

+    def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
+        for hook in self._hooks:
+            getattr(hook, func_name)(self, *args, **kwargs)
+
     def _load_accum_batch(self, data: Any, label: Any):
         """Loads a batch of data and label for gradient accumulation.

@@ -77,7 +88,6 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         scale_loss: int = 1,
-        post_fn: Callable = None,
     ):
         """Trains one batch of data.

@@ -89,23 +99,27 @@ class NonPipelineScheduler(BaseScheduler):
                 be executed.
             return_loss (bool, optional): Loss will be returned if True.
             scale_loss (int, optional): The scale factor for the loss.
-            post_fn (Callable, optional): Call back function after executing data forward output.
         """

         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
+            self._call_hooks("before_forward", data)
             output = self._call_engine(engine, data)
+            self._call_hooks("after_forward", output)

-            if post_fn is not None:
-                post_fn(output, label)
+            self._call_hooks("post_helper_func", output, label)

             if return_loss:
+                self._call_hooks("before_criterion", output, label)
                 loss = self._call_engine_criterion(engine, output, label)
+                self._call_hooks("after_criterion", loss)
                 loss /= scale_loss

         # backward
         if not forward_only:
+            self._call_hooks("before_backward", None, None)
             engine.backward(loss)
+            self._call_hooks("after_backward", None)

         if not return_loss:
             loss = None
@@ -119,7 +133,6 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         return_output_label: bool = True,
-        post_fn: Callable = None,
     ):
         """The process function that loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.
@@ -131,7 +144,6 @@ class NonPipelineScheduler(BaseScheduler):
                 If True, the model is run for the forward pass, else back propagation will be executed.
             return_loss (bool, optional): Loss will be returned if True.
             return_output_label (bool, optional): Output and label will be returned if True.
-            post_fn (Callable, optional): Call back function after executing data forward output.

         Returns:
             Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
@@ -165,7 +177,7 @@ class NonPipelineScheduler(BaseScheduler):
             _data, _label = self._load_accum_batch(data, label)

             _output, _loss = self._train_one_batch(
-                _data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
+                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )

         if return_loss:
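Wiring this up looks like the sketch below (hypothetical values); the hook list replaces the removed post_fn argument, and since _call_hooks iterates self._hooks directly, callers are expected to pass a list:

scheduler = NonPipelineScheduler(
    data_process_func=None,
    gradient_accumulation_size=4,  # hypothetical
    scheduler_hooks=[SchedulerMetricHook(metric=my_metric)],  # my_metric: an AccPerplex-style callable
)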

File diff suppressed because it is too large.


@@ -3,7 +3,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize

-from typing import Callable, Iterable, Optional, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple

 from torch import nn
 from torch.nn.modules.loss import _Loss
@@ -15,12 +15,13 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
-from internlm.core.scheduler.no_pipeline_scheduler import NonPipelineScheduler
-from internlm.core.scheduler.pipeline_scheduler import (
+from internlm.core.scheduler import (
     InterleavedPipelineScheduler,
+    NonPipelineScheduler,
     PipelineScheduler,
-    get_tensor_shape,
+    SchedulerHook,
 )
+from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
 from internlm.core.trainer import Trainer
 from internlm.data.utils import unpack_data
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -36,6 +37,7 @@ def initialize_trainer(
     test_dataloader: Optional[Iterable] = None,
     lr_scheduler: Optional[_LRScheduler] = None,
     beta2_scheduler: Optional[Beta2Scheduler] = None,
+    scheduler_hooks: Optional[List[SchedulerHook]] = None,
 ) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
@@ -92,12 +94,16 @@ def initialize_trainer(
         if use_interleaved:
             if isinstance(model, nn.Sequential):
                 model = nn.ModuleList([model])
+
+            communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
             scheduler = InterleavedPipelineScheduler(
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
-                num_model_chunks=gpc.config.model.num_chunks,
+                num_chunks=gpc.config.model.num_chunks,
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
+                scheduler_hooks=scheduler_hooks,
+                communication_overlap=communication_overlap,
             )
         else:
             scheduler = PipelineScheduler(
@@ -106,10 +112,13 @@ def initialize_trainer(
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
+                scheduler_hooks=scheduler_hooks,
             )
     else:
         scheduler = NonPipelineScheduler(
-            data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
+            data_process_func=data_fn,
+            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
+            scheduler_hooks=scheduler_hooks,
         )

     # initialize engine for trainer


@@ -7,40 +7,47 @@ from tqdm import tqdm

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.metrics import AccPerplex
+from internlm.core.scheduler import SchedulerMetricHook


 @contextmanager
-def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
+def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
     if not gpc.is_using_pp():
         prev_data_process_func = trainer.schedule.data_process_func
         prev_grad_accum_size = trainer.schedule._grad_accum_size
         prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
+        prev_metric_hooks = trainer.schedule._hooks
         try:
             trainer.schedule.data_process_func = None
             trainer.schedule._grad_accum_size = grad_accum_size
             trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
+            trainer.schedule._hooks = metric_hook_list
             yield
         finally:
             trainer.schedule.data_process_func = prev_data_process_func
             trainer.schedule._grad_accum_size = prev_grad_accum_size
             trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
+            trainer.schedule._hooks = prev_metric_hooks


 @contextmanager
-def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
+def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
     if gpc.is_using_pp():
         pre_data_process_func = trainer.schedule.data_process_func
         prev_num_microbatches = trainer.schedule.num_microbatches
         prev_tensor_shape = trainer.schedule.tensor_shape
+        prev_metric_hooks = trainer.schedule._hooks
         try:
             trainer.schedule.data_process_func = None
             trainer.schedule.num_microbatches = num_microbatches
             trainer.schedule.tensor_shape = tensor_shape
+            trainer.schedule._hooks = metric_hook_list
             yield
         finally:
             trainer.schedule.data_process_func = pre_data_process_func
             trainer.schedule.num_microbatches = prev_num_microbatches
             trainer.schedule.tensor_shape = prev_tensor_shape
+            trainer.schedule._hooks = prev_metric_hooks


 def evaluate_on_val_dls(
@@ -49,7 +56,6 @@ def evaluate_on_val_dls(
     writer,
     logger,
     step_count,
-    tokenizer=None,
     update_panel: bool = False,
 ):
     torch.cuda.empty_cache()
@@ -66,8 +72,9 @@ def evaluate_on_val_dls(
             device=torch.cuda.current_device(),
             tp_pg=gpc.get_group(ParallelMode.TENSOR),
             dp_pg=gpc.get_group(ParallelMode.DATA),
-            tokenizer=tokenizer,
         )
+        val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
+
         val_loss = 0
         val_idx = -1
         for val_idx, batch in tqdm(
@@ -88,10 +95,13 @@ def evaluate_on_val_dls(
                 )

                 with switch_evaluation_pipeline_scheduler(
-                    trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
+                    trainer=trainer,
+                    num_microbatches=num_microbatches,
+                    tensor_shape=tensor_shape,
+                    metric_hook_list=[val_sche_metric_hook],
                 ):
                     _, _, loss = trainer.execute_schedule(
-                        batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
+                        batch, forward_only=True, return_loss=True, return_output_label=False
                     )
             else:
                 total_val_bsz = len(batch[1])
@@ -100,38 +110,42 @@ def evaluate_on_val_dls(
                 grad_accum_batch_size = data_cfg.micro_bsz

                 with switch_evaluation_no_pipeline_scheduler(
-                    trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
+                    trainer=trainer,
+                    grad_accum_size=grad_accum_size,
+                    grad_accum_batch_size=grad_accum_batch_size,
+                    metric_hook_list=[val_sche_metric_hook],
                 ):
                     _, _, loss = trainer.execute_schedule(
-                        batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
+                        batch, forward_only=True, return_loss=True, return_output_label=False
                     )

             if verbose:
                 val_loss += loss.item()

         assert val_idx != -1
         dist.barrier()
-        val_res = val_metric.get_metric()

+        val_res = val_metric.get_metric()
         if verbose and len(val_dl) != 0:
             val_loss = val_loss / (val_idx + 1 + 1e-6)
             infos = {
-                "step": step_count,
                 f"val/{val_name}_loss": val_loss,
                 f"val/{val_name}_acc": val_res["acc"],
                 f"val/{val_name}_plex": val_res["perplexity"],
             }
+            val_metric = {
+                "step": step_count,
+                "val_loss": val_loss,
+                "val_acc": val_res["acc"],
+                "val_perplexity": val_res["perplexity"],
+            }

             for key, value in infos.items():
                 writer.add_scalar(key=key, value=value, step=step_count)
+            infos["step"] = step_count

             if update_panel:
                 logger.info(
                     f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
-                    extra={
-                        "step": step_count,
-                        "val_loss": val_loss,
-                        "val_acc": val_res["acc"],
-                        "val_perplexity": val_res["perplexity"],
-                    }
+                    extra=val_metric,
                 )
             else:
                 logger.info(
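Putting it together, evaluation now injects the metric through a temporary hook swap instead of the removed post_fn argument; a sketch with hypothetical accumulation sizes:

val_hook = SchedulerMetricHook(metric=val_metric)

with switch_evaluation_no_pipeline_scheduler(
    trainer=trainer,
    grad_accum_size=1,
    grad_accum_batch_size=1,
    metric_hook_list=[val_hook],
):
    _, _, loss = trainer.execute_schedule(
        batch, forward_only=True, return_loss=True, return_output_label=False
    )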


@@ -16,6 +16,7 @@ import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.core.scheduler import SchedulerMetricHook
 from internlm.core.trainer import TrainState
 from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
 from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
@@ -109,6 +110,19 @@ def initialize_model():
     """
     model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))

-    model = NaiveAMPModel(
-        model=model,
-        output_to_fp32=is_no_pp_or_last_stage(),
+    if isinstance(model, nn.ModuleList):
+        model = nn.ModuleList(
+            [
+                NaiveAMPModel(
+                    model=_m,
+                    output_to_fp32=False,  # manually controlled by interleaved pipeline scheduler
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    sync_buffer=False,
+                )
+                for _m in model
+            ]
+        )
+    else:
+        model = NaiveAMPModel(
+            model=model,
+            output_to_fp32=is_no_pp_or_last_stage(),
@@ -500,19 +514,6 @@ def main(args):
     if load_optimizer:
         load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)

-    # initialize trainer
-    trainer, train_dl, _, _ = internlm.initialize_trainer(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        train_dataloader=train_dl,
-        lr_scheduler=lr_scheduler,
-        beta2_scheduler=beta2_scheduler,
-    )
-
-    # initialize the batch skipper
-    batch_skipper = BatchSkipper(skip_batches)
-
     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
         device=torch.cuda.current_device(),
@@ -521,6 +522,27 @@ def main(args):
         dataset_types=dataset_types,
     )

+    # initialize trainer
+    scheduler_hooks = [
+        SchedulerMetricHook(
+            metric=metric,
+            skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
+        ),
+    ]
+
+    trainer, train_dl, _, _ = internlm.initialize_trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dl,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        scheduler_hooks=scheduler_hooks,
+    )
+
+    # initialize the batch skipper
+    batch_skipper = BatchSkipper(skip_batches)
+
     trainer.train()

     # transfer the train data loader into train data iterator
@@ -558,9 +580,7 @@ def main(args):

         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(
-            batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
-        )
+        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)