refactor(scheduler): rewrite pipeline scheduler (#138)

* refactor(scheduler): rewrite pipeline scheduler

* fix(*): fix pipeline scheduler bugs

* fix(*): fix merge bug

* feat(*): update codes with todo tag

* feat(*): add comments

* feat(internlm/core/scheduler): update recv_prev/next logic

* feat(utils/evaluation.py): update sche metric hook for valid

---------

Co-authored-by: huangting.p <huangting@sensetime.com>
cx 2023-08-03 11:48:12 +08:00 committed by GitHub
parent d67be17f96
commit 0268d8eda1
12 changed files with 1412 additions and 712 deletions


@ -117,6 +117,7 @@ model = dict(
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
zero1 parallel:
@ -125,12 +126,14 @@ zero1 parallel:
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel: pipeline parallel size.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=8,
pipeline=2,
pipeline=dict(size=1, interleaved_overlap=True),
)
cudnn_deterministic = False
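For context, a hedged illustration of the new schema (the values here are examples, not taken from this commit): the interleaved scheduler is selected by setting num_chunks > 1 in the model dict together with a pipeline size > 1, while interleaved_overlap toggles communication/computation overlap.

model = dict(
    # ... other model fields as above ...
    num_chunks=2,  # > 1 enables the interleaved pipeline scheduler
)
parallel = dict(
    zero1=8,
    pipeline=dict(size=4, interleaved_overlap=True),  # 4 pipeline stages with overlapped p2p
)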


@ -1,10 +1,13 @@
from .p2p import (
AsynCommunicator,
recv_backward,
recv_forward,
send_backward,
send_backward_and_recv_next_backward_async,
send_backward_recv_backward,
send_backward_recv_forward,
send_forward,
send_forward_and_recv_next_forward_async,
send_forward_backward_recv_forward_backward,
send_forward_recv_backward,
send_forward_recv_forward,
@ -23,4 +26,7 @@ __all__ = [
"recv_forward",
"send_obj_meta",
"recv_obj_meta",
"send_backward_and_recv_next_backward_async",
"send_forward_and_recv_next_forward_async",
"AsynCommunicator",
]


@ -207,16 +207,13 @@ def recv_forward(
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(
recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
input_tensor, _ = _communicate(
recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
@ -233,16 +230,13 @@ def recv_backward(
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(
recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
_, output_tensor_grad = _communicate(
recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
@ -253,8 +247,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) ->
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not gpc.is_pipeline_last_stage():
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
@ -264,14 +257,12 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not gpc.is_pipeline_first_stage():
_communicate(
object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
)
_communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_forward_recv_backward(
output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
@ -285,24 +276,21 @@ def send_forward_recv_backward(
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(
object_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
_, output_tensor_grad = _communicate(
object_send_next=output_tensor,
recv_next=output_grad_shape is not None,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
@ -319,24 +307,21 @@ def send_backward_recv_forward(
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
input_tensor, _ = _communicate(
object_send_prev=input_tensor_grad,
recv_prev=input_tensor_shape is not None,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def send_forward_recv_forward(
output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
@ -356,7 +341,7 @@ def send_forward_recv_forward(
"""
input_tensor, _ = _communicate(
object_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev=input_tensor_shape is not None,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
@ -369,7 +354,6 @@ def send_forward_recv_forward(
def send_backward_recv_backward(
input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
@ -389,7 +373,7 @@ def send_backward_recv_backward(
"""
_, output_tensor_grad = _communicate(
object_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next=output_grad_shape is not None,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
@ -404,8 +388,6 @@ def send_forward_backward_recv_forward_backward(
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
@ -430,8 +412,8 @@ def send_forward_backward_recv_forward_backward(
input_tensor, output_tensor_grad = _communicate(
object_send_next=output_tensor,
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev=input_tensor_shape is not None,
recv_next=output_grad_shape is not None,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
@ -440,3 +422,159 @@ def send_forward_backward_recv_forward_backward(
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor, output_tensor_grad
def send_forward_and_recv_next_forward_async(
output_tensor,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
dtype: torch.dtype = None,
scatter_gather_tensors=False,
):
"""send forward output to next rank and recv forward input from prev rank"""
reqs = []
tensor_recv_prev = None
# prepare send operations
if output_tensor is not None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
if isinstance(output_tensor, torch.Tensor):
reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
else:
for tensor_to_comm in output_tensor:
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
# prepare receive operations
if recv_prev_shape is not None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
# create receive buffer
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
recv_prev_shape, dtype, scatter_gather_tensors
)
# generate async receive operations
if isinstance(tensor_recv_prev, torch.Tensor):
reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
else:
for tensor_to_comm in tensor_recv_prev:
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
# return and do other things
yield
# check communication completed
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv()
torch.cuda.synchronize()
# Process received data
if recv_prev_shape is not None and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = (
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
)
yield tensor_recv_prev
def send_backward_and_recv_next_backward_async(
input_tensor,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
dtype: torch.dtype = None,
scatter_gather_tensors=False,
):
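"""send backward input gradient to prev rank and recv backward output gradient from next rank"""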
reqs = []
tensor_recv_next = None
# prepare send operations
if input_tensor is not None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
if isinstance(input_tensor, torch.Tensor):
reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
else:
for tensor_to_comm in input_tensor:
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
# prepare receive operations
if recv_next_shape is not None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
# create receive buffer
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
recv_next_shape, dtype, scatter_gather_tensors
)
# generate async receive operations
if isinstance(tensor_recv_next, torch.Tensor):
reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
else:
for tensor_to_comm in tensor_recv_next:
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
# return and do other things
yield
# check communication completed
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv()
torch.cuda.synchronize()
# Process received data
if recv_next_shape is not None and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = (
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
)
yield tensor_recv_next
class AsynCommunicator:
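"""Wraps the async p2p coroutines above: start() launches the batched isend/irecv ops, wait_and_receive() waits for them and returns the received tensor(s)."""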
def __init__(
self,
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
recv_shape: Union[torch.Size, List[torch.Size]],
dtype: torch.dtype = None,
scatter_gather_tensors=False,
forward: bool = True,
) -> None:
self._need_receive = recv_shape is not None
if forward:
self._coroutine = send_forward_and_recv_next_forward_async(
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
)
else:
self._coroutine = send_backward_and_recv_next_backward_async(
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
)
@property
def need_receive(self) -> bool:
return self._need_receive
def start(self) -> None:
next(self._coroutine)
def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
received = next(self._coroutine)
self._coroutine.close()
return received
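A minimal usage sketch (scheduler-side names such as compute_next_microbatch are hypothetical, not from this commit): start() issues the batched isend/irecv ops, the caller overlaps them with computation, and wait_and_receive() blocks until the tensors have arrived.

comm = AsynCommunicator(
    tensor_to_send=output_tensor,   # tensor produced by the current stage
    recv_shape=input_tensor_shape,  # shape expected from the previous stage
    dtype=torch.float,
    forward=True,
)
comm.start()                                  # issue async isend/irecv
result = compute_next_microbatch()            # hypothetical overlapped computation
if comm.need_receive:
    next_input = comm.wait_and_receive()      # wait and collect the received tensor(s)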


@ -1,5 +1,6 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
from functools import wraps
from typing import List, Tuple, Union
import torch
@ -19,7 +20,7 @@ def send_meta_helper(obj, next_rank, tensor_kwargs):
dist.send(send_shape, next_rank)
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
def send_obj_meta(obj, next_rank=None):
"""Sends obj meta information before sending a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be sent before communications. This function
@ -33,22 +34,19 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
Returns:
bool: False
"""
if need_meta:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
send_meta_helper(obj, next_rank, tensor_kwargs)
else:
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
for tensor_to_send in obj:
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
return False
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
send_meta_helper(obj, next_rank, tensor_kwargs)
else:
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
for tensor_to_send in obj:
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
def recv_meta_helper(prev_rank, tensor_kwargs):
@ -59,7 +57,7 @@ def recv_meta_helper(prev_rank, tensor_kwargs):
return recv_shape
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
def recv_obj_meta(prev_rank=None) -> torch.Size:
"""Receives obj meta information before receiving a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be received before communications. This function
@ -72,21 +70,20 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
Returns:
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
"""
if obj_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape = torch.Size(recv_shape)
else:
obj_shape = []
for _ in range(recv_obj_nums.item()):
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape = torch.Size(recv_shape)
else:
obj_shape = []
for _ in range(recv_obj_nums.item()):
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape.append(torch.Size(recv_shape))
obj_shape.append(torch.Size(recv_shape))
return obj_shape
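With the simplified signatures, the shape handshake between two adjacent pipeline stages reduces to a pair of calls; the sketch below is illustrative (tensor names are assumptions, ranks default to the pipeline-parallel neighbours):

# sending stage: announce the number of tensors and their shapes, then send the data
send_obj_meta(output_tensor)
send_forward(output_tensor)

# receiving stage: recover the shape(s) first, then receive the actual tensor(s)
input_shape = recv_obj_meta()  # torch.Size or a list of torch.Size
input_tensor = recv_forward(input_shape, dtype=torch.float)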


@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
input_ = input_.float()
return input_
def convert_to_fp32(self, out):
"""Converts the output to fp32"""
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out
def _reduce_module_buffer(self):
"""
All-reduces the buffers (e.g., running stats of batch normalization) across
@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
out = self.model(*args, **kwargs)
if self._output_to_fp32:
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
out = self.convert_to_fp32(out)
return out


@ -1,5 +1,12 @@
from .base_scheduler import BaseScheduler
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
__all__ = ["BaseScheduler", "NonPipelineScheduler", "InterleavedPipelineScheduler", "PipelineScheduler"]
__all__ = [
"BaseScheduler",
"NonPipelineScheduler",
"InterleavedPipelineScheduler",
"PipelineScheduler",
"SchedulerHook",
"SchedulerMetricHook",
]


@ -4,11 +4,12 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Optional
import torch
from internlm.core.engine import Engine
from internlm.utils.megatron_timers import megatron_timer as timer
class BaseScheduler(ABC):
@ -112,3 +113,85 @@ class BaseScheduler(ABC):
'(which is auto-converted to tuple), list, tuple, or dict, ' \
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
)
class SchedulerHook(ABC):
"""
Scheduler Hook.
"""
@abstractmethod
def before_forward(self, scheduler, inputs) -> None:
"""Actions before forward"""
@abstractmethod
def after_forward(self, scheduler, outputs) -> None:
"""Actions after forward"""
@abstractmethod
def before_criterion(self, scheduler, outputs, label) -> None:
"""Actions before criterion"""
@abstractmethod
def after_criterion(self, scheduler, loss) -> None:
"""Actions after criterion"""
@abstractmethod
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
"""Actions before backward"""
@abstractmethod
def after_backward(self, scheduler, inputs_grad) -> None:
"""Actions after backward"""
@abstractmethod
def post_helper_func(self, scheduler, outputs, label) -> None:
"""A post helper function"""
class SchedulerMetricHook(SchedulerHook):
"""
Scheduler Metric Hook.
"""
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
self._post_func = metric
self._skip = skip
if skip:
# init timer only.
timer("fwd")
timer("bwd")
timer("cal_loss")
timer("post_fn")
def before_forward(self, scheduler, inputs) -> None:
if not self._skip:
timer("fwd").start()
def after_forward(self, scheduler, outputs) -> None:
if not self._skip:
timer("fwd").stop()
def before_criterion(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("cal_loss").start()
def after_criterion(self, scheduler, loss) -> None:
if not self._skip:
timer("cal_loss").stop()
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
if not self._skip:
timer("bwd").start()
def after_backward(self, scheduler, inputs_grad) -> None:
if not self._skip:
timer("bwd").stop()
def post_helper_func(self, scheduler, outputs, label) -> None:
if not self._skip:
timer("post_fn").start()
if self._post_func is not None:
self._post_func(outputs, label)
timer("post_fn").stop()


@ -3,14 +3,14 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, List, Optional
import torch
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
from .base_scheduler import BaseScheduler
from .base_scheduler import BaseScheduler, SchedulerHook
class NonPipelineScheduler(BaseScheduler):
@ -34,10 +34,17 @@ class NonPipelineScheduler(BaseScheduler):
return data, label
"""
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
def __init__(
self,
data_process_func: Callable = None,
gradient_accumulation_size: int = 1,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
):
self._grad_accum_size = gradient_accumulation_size
self._grad_accum_offset = 0
self._hooks = scheduler_hooks
super().__init__(data_process_func)
def pre_processing(self, engine: Engine):
@ -48,6 +55,10 @@ class NonPipelineScheduler(BaseScheduler):
"""
pass
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
for hook in self._hooks:
getattr(hook, func_name)(self, *args, **kwargs)
def _load_accum_batch(self, data: Any, label: Any):
"""Loads a batch of data and label for gradient accumulation.
@ -77,7 +88,6 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
scale_loss: int = 1,
post_fn: Callable = None,
):
"""Trains one batch of data.
@ -89,23 +99,27 @@ class NonPipelineScheduler(BaseScheduler):
be executed.
return_loss (bool, optional): Loss will be returned if True.
scale_loss (int, optional): The scale factor for the loss.
post_fn (Callable, optional): Call back function after executing data forward output.
"""
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
output = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
if post_fn is not None:
post_fn(output, label)
self._call_hooks("post_helper_func", output, label)
if return_loss:
self._call_hooks("before_criterion", output, label)
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
loss /= scale_loss
# backward
if not forward_only:
self._call_hooks("before_backward", None, None)
engine.backward(loss)
self._call_hooks("after_backward", None)
if not return_loss:
loss = None
@ -119,7 +133,6 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
post_fn: Callable = None,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
@ -131,7 +144,6 @@ class NonPipelineScheduler(BaseScheduler):
If True, the model is run for the forward pass, else back propagation will be executed.
return_loss (bool, optional): Loss will be returned if True.
return_output_label (bool, optional): Output and label will be returned if True.
post_fn (Callable, optional): Call back function after executing data forward output.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
@ -165,7 +177,7 @@ class NonPipelineScheduler(BaseScheduler):
_data, _label = self._load_accum_batch(data, label)
_output, _loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
if return_loss:

File diff suppressed because it is too large


@ -3,7 +3,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
from typing import Callable, Iterable, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple
from torch import nn
from torch.nn.modules.loss import _Loss
@ -15,12 +15,13 @@ from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.scheduler.no_pipeline_scheduler import NonPipelineScheduler
from internlm.core.scheduler.pipeline_scheduler import (
from internlm.core.scheduler import (
InterleavedPipelineScheduler,
NonPipelineScheduler,
PipelineScheduler,
get_tensor_shape,
SchedulerHook,
)
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
from internlm.core.trainer import Trainer
from internlm.data.utils import unpack_data
from internlm.solver.beta2_scheduler import Beta2Scheduler
@ -36,6 +37,7 @@ def initialize_trainer(
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
beta2_scheduler: Optional[Beta2Scheduler] = None,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
@ -92,12 +94,16 @@ def initialize_trainer(
if use_interleaved:
if isinstance(model, nn.Sequential):
model = nn.ModuleList([model])
communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
scheduler = InterleavedPipelineScheduler(
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
num_model_chunks=gpc.config.model.num_chunks,
num_chunks=gpc.config.model.num_chunks,
dtype=gpc.config.model["dtype"],
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather,
scheduler_hooks=scheduler_hooks,
communication_overlap=communication_overlap,
)
else:
scheduler = PipelineScheduler(
@ -106,10 +112,13 @@ def initialize_trainer(
dtype=gpc.config.model["dtype"],
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather,
scheduler_hooks=scheduler_hooks,
)
else:
scheduler = NonPipelineScheduler(
data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
data_process_func=data_fn,
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
scheduler_hooks=scheduler_hooks,
)
# initialize engine for trainer


@ -7,40 +7,47 @@ from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.metrics import AccPerplex
from internlm.core.scheduler import SchedulerMetricHook
@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
if not gpc.is_using_pp():
prev_data_process_func = trainer.schedule.data_process_func
prev_grad_accum_size = trainer.schedule._grad_accum_size
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
prev_metric_hooks = trainer.schedule._hooks
try:
trainer.schedule.data_process_func = None
trainer.schedule._grad_accum_size = grad_accum_size
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
trainer.schedule._hooks = metric_hook_list
yield
finally:
trainer.schedule.data_process_func = prev_data_process_func
trainer.schedule._grad_accum_size = prev_grad_accum_size
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
trainer.schedule._hooks = prev_metric_hooks
@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
if gpc.is_using_pp():
pre_data_process_func = trainer.schedule.data_process_func
prev_num_microbatches = trainer.schedule.num_microbatches
prev_tensor_shape = trainer.schedule.tensor_shape
prev_metric_hooks = trainer.schedule._hooks
try:
trainer.schedule.data_process_func = None
trainer.schedule.num_microbatches = num_microbatches
trainer.schedule.tensor_shape = tensor_shape
trainer.schedule._hooks = metric_hook_list
yield
finally:
trainer.schedule.data_process_func = pre_data_process_func
trainer.schedule.num_microbatches = prev_num_microbatches
trainer.schedule.tensor_shape = prev_tensor_shape
trainer.schedule._hooks = prev_metric_hooks
def evaluate_on_val_dls(
@ -49,7 +56,6 @@ def evaluate_on_val_dls(
writer,
logger,
step_count,
tokenizer=None,
update_panel: bool = False,
):
torch.cuda.empty_cache()
@ -66,8 +72,9 @@ def evaluate_on_val_dls(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
tokenizer=tokenizer,
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
@ -88,10 +95,13 @@ def evaluate_on_val_dls(
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
@ -100,38 +110,42 @@ def evaluate_on_val_dls(
grad_accum_batch_size = data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
trainer=trainer,
grad_accum_size=grad_accum_size,
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
val_res = val_metric.get_metric()
val_res = val_metric.get_metric()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
infos = {
"step": step_count,
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
val_metric = {
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
infos["step"] = step_count
if update_panel:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
extra=val_metric,
extra={
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
},
)
else:
logger.info(


@ -16,6 +16,7 @@ import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
@ -109,12 +110,25 @@ def initialize_model():
"""
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
if isinstance(model, nn.ModuleList):
model = nn.ModuleList(
[
NaiveAMPModel(
model=_m,
output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
for _m in model
]
)
else:
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
# This sync is very important, cause the model weights kept in optimizer are copied
# from the origin parameters in the memory, so we should make sure the dp sync
@ -500,19 +514,6 @@ def main(args):
if load_optimizer:
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
# initialize trainer
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
)
# initialize the batch skipper
batch_skipper = BatchSkipper(skip_batches)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
@ -521,6 +522,27 @@ def main(args):
dataset_types=dataset_types,
)
# initialize trainer
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
),
]
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks,
)
# initialize the batch skipper
batch_skipper = BatchSkipper(skip_batches)
trainer.train()
# transfer the train data loader into train data iterator
@ -558,9 +580,7 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
)
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)