mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved docstring and assertion messages for the engine module (#871)
parent
1c34382678
commit
11f54c7b6b
|
@ -1,11 +1,9 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
from asyncio.log import logger
|
|
||||||
from typing import List, Iterable
|
from typing import List, Iterable
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -23,7 +21,7 @@ class Engine:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (``torch.nn.Module``): The neural network model.
|
model (``torch.nn.Module``): The neural network model.
|
||||||
optimizer (``torch.optim.Optimizer``): Optimizer for updating the parameters.
|
optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
|
||||||
criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
|
criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
|
||||||
gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
|
gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
|
||||||
clip_grad_norm (float, optional): The norm of gradient clipping.
|
clip_grad_norm (float, optional): The norm of gradient clipping.
|
||||||
|
@ -57,7 +55,7 @@ class Engine:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: Module,
|
model: Module,
|
||||||
optimizer: Optimizer,
|
optimizer: "ColossalaiOptimizer",
|
||||||
criterion: Optional[_Loss] = None,
|
criterion: Optional[_Loss] = None,
|
||||||
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
||||||
clip_grad_norm: float = 0.0,
|
clip_grad_norm: float = 0.0,
|
||||||
|
@ -87,6 +85,8 @@ class Engine:
|
||||||
|
|
||||||
# build schedule
|
# build schedule
|
||||||
if schedule:
|
if schedule:
|
||||||
|
assert isinstance(schedule, BaseSchedule), \
|
||||||
|
f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
|
||||||
self._schedule = schedule
|
self._schedule = schedule
|
||||||
else:
|
else:
|
||||||
self._schedule = NonPipelineSchedule()
|
self._schedule = NonPipelineSchedule()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Iterable, Any
|
from typing import Iterable, Any, Tuple
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
@ -33,24 +34,54 @@ class GradAccumOptimizer(ColossalaiOptimizer):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
|
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
|
||||||
|
|
||||||
def zero_grad(self, *args, **kwargs):
|
def zero_grad(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Set all gradients to zero.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: positional arguments for the optimizer wrapped
|
||||||
|
**kwargs: keyword arguments for the optimizer wrapped
|
||||||
|
"""
|
||||||
|
|
||||||
if self.accumulate_step == 0:
|
if self.accumulate_step == 0:
|
||||||
self.optim.zero_grad(*args, **kwargs)
|
self.optim.zero_grad(*args, **kwargs)
|
||||||
|
|
||||||
def step(self, *args, **kwargs):
|
def step(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Update the model parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: positional arguments for the optimizer wrapped
|
||||||
|
**kwargs: keyword arguments for the optimizer wrapped
|
||||||
|
"""
|
||||||
|
|
||||||
if self.accumulate_step < self.accumulate_size:
|
if self.accumulate_step < self.accumulate_size:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
self.accumulate_step = 0
|
self.accumulate_step = 0
|
||||||
return self.optim.step(*args, **kwargs)
|
return self.optim.step(*args, **kwargs)
|
||||||
|
|
||||||
def clip_grad_norm(self, model: nn.Module, max_norm: float):
|
def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
|
||||||
|
"""
|
||||||
|
Clip gradients by norm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:class:`torch.nn.Module`): a torch module instance
|
||||||
|
max_norm (float): the max norm for gradient clipping
|
||||||
|
"""
|
||||||
|
|
||||||
if self.accumulate_step < self.accumulate_size:
|
if self.accumulate_step < self.accumulate_size:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.optim.clip_grad_norm(model, max_norm)
|
self.optim.clip_grad_norm(model, max_norm)
|
||||||
|
|
||||||
def backward(self, loss: Tensor):
|
def backward(self, loss: Tensor) -> None:
|
||||||
|
"""Execute backward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (:class:`torch.Tensor`): the loss value.
|
||||||
|
"""
|
||||||
|
|
||||||
self.accumulate_step += 1
|
self.accumulate_step += 1
|
||||||
|
|
||||||
if self.is_torch_ddp:
|
if self.is_torch_ddp:
|
||||||
|
@ -62,7 +93,14 @@ class GradAccumOptimizer(ColossalaiOptimizer):
|
||||||
scaled_loss = loss / self.accumulate_size
|
scaled_loss = loss / self.accumulate_size
|
||||||
self.optim.backward(scaled_loss)
|
self.optim.backward(scaled_loss)
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
|
||||||
|
"""Execute backward pass given the gradients of the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (:class:`torch.Tensor`): the loss value.
|
||||||
|
grad (:class:`torch.Tensor`): the output gradient.
|
||||||
|
"""
|
||||||
|
|
||||||
self.accumulate_step += 1
|
self.accumulate_step += 1
|
||||||
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
|
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
|
||||||
|
|
||||||
|
@ -84,7 +122,7 @@ class GradAccumDataloader:
|
||||||
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
|
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optim (``Iterable``): Your dataloader object for gradient accumulation.
|
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
|
||||||
accumulate_size (int): The number of steps to accumulate gradients.
|
accumulate_size (int): The number of steps to accumulate gradients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -96,15 +134,15 @@ class GradAccumDataloader:
|
||||||
def __getattr__(self, __name: str) -> Any:
|
def __getattr__(self, __name: str) -> Any:
|
||||||
return getattr(self.dataloader, __name)
|
return getattr(self.dataloader, __name)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return self.steps_per_epoch
|
return self.steps_per_epoch
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterable:
|
||||||
self._cur_step = 0
|
self._cur_step = 0
|
||||||
self._dataiter = iter(self.dataloader)
|
self._dataiter = iter(self.dataloader)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self) -> Any:
|
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
|
||||||
if self._cur_step < self.steps_per_epoch:
|
if self._cur_step < self.steps_per_epoch:
|
||||||
self._cur_step += 1
|
self._cur_step += 1
|
||||||
|
|
||||||
|
@ -137,13 +175,30 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
|
||||||
self.accumulate_step = 0
|
self.accumulate_step = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
|
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int:
|
||||||
|
"""
|
||||||
|
Computes the number of effective training iterations. An effective iteration is defined
|
||||||
|
as the the aggregation of <accumulate_size> iterations. For examples, if accumulate_size = 4,
|
||||||
|
then 4 iterations are considered as one effective iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
|
||||||
|
accumulate_size (int): The number of steps to accumulate gradients.
|
||||||
|
|
||||||
|
"""
|
||||||
return len(dataloader) // accumulate_size
|
return len(dataloader) // accumulate_size
|
||||||
|
|
||||||
def __getattr__(self, __name: str) -> Any:
|
def __getattr__(self, __name: str) -> Any:
|
||||||
return getattr(self.lr_scheduler, __name)
|
return getattr(self.lr_scheduler, __name)
|
||||||
|
|
||||||
def step(self, *args, **kwargs):
|
def step(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Update the learning rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: positional arguments for the lr scheduler wrapped.
|
||||||
|
**kwargs: keyword arguments for the lr scheduler wrapped.
|
||||||
|
"""
|
||||||
self.accumulate_step += 1
|
self.accumulate_step += 1
|
||||||
if self.accumulate_step < self.accumulate_size:
|
if self.accumulate_step < self.accumulate_size:
|
||||||
pass
|
pass
|
||||||
|
@ -151,19 +206,52 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
|
||||||
self.accumulate_step = 0
|
self.accumulate_step = 0
|
||||||
self.lr_scheduler.step(*args, **kwargs)
|
self.lr_scheduler.step(*args, **kwargs)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self) -> Tensor:
|
||||||
|
"""
|
||||||
|
Compute the next learning rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the upcoming learning rate.
|
||||||
|
"""
|
||||||
|
|
||||||
return self.lr_scheduler.get_lr()
|
return self.lr_scheduler.get_lr()
|
||||||
|
|
||||||
def get_last_lr(self):
|
def get_last_lr(self) -> Tensor:
|
||||||
|
"""
|
||||||
|
Returns the current learning rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the current learning rate.
|
||||||
|
"""
|
||||||
|
|
||||||
return self.lr_scheduler.get_last_lr()
|
return self.lr_scheduler.get_last_lr()
|
||||||
|
|
||||||
def print_lr(self, *args, **kwargs):
|
def print_lr(self, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Print he learning rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: positional arguments for the lr scheduler wrapped.
|
||||||
|
**kwargs: keyword arguments for the lr scheduler wrapped.
|
||||||
|
"""
|
||||||
self.lr_scheduler.print_lr(*args, **kwargs)
|
self.lr_scheduler.print_lr(*args, **kwargs)
|
||||||
|
|
||||||
def state_dict(self) -> dict:
|
def state_dict(self) -> dict:
|
||||||
|
"""
|
||||||
|
Returns the states of the lr scheduler as dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: the states of the lr scheduler.
|
||||||
|
"""
|
||||||
return self.lr_scheduler.state_dict()
|
return self.lr_scheduler.state_dict()
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: dict) -> None:
|
def load_state_dict(self, state_dict: dict) -> None:
|
||||||
|
"""
|
||||||
|
Load the states of the lr scheduler from a dictionary object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: the states of the lr scheduler.
|
||||||
|
"""
|
||||||
self.lr_scheduler.load_state_dict(state_dict)
|
self.lr_scheduler.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,7 +276,11 @@ class GradAccumGradientHandler:
|
||||||
self.accumulate_size = accumulate_size
|
self.accumulate_size = accumulate_size
|
||||||
self.accumulate_step = 0
|
self.accumulate_step = 0
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self) -> None:
|
||||||
|
"""
|
||||||
|
Handle gradients reduction only in the last gradient accumulation step.
|
||||||
|
"""
|
||||||
|
|
||||||
self.accumulate_step += 1
|
self.accumulate_step += 1
|
||||||
if self.accumulate_step < self.accumulate_size:
|
if self.accumulate_step < self.accumulate_size:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -12,6 +12,10 @@ class DataParallelGradientHandler(BaseGradientHandler):
|
||||||
:func:`handle_gradient` among a data parallel group.
|
:func:`handle_gradient` among a data parallel group.
|
||||||
For better performance, it bucketizes the gradients of all parameters that are
|
For better performance, it bucketizes the gradients of all parameters that are
|
||||||
the same type to improve the efficiency of communication.
|
the same type to improve the efficiency of communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Model where the gradients accumulate.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self):
|
||||||
|
|
|
@ -14,6 +14,10 @@ class MoeGradientHandler(BaseGradientHandler):
|
||||||
:func:`handle_gradient` among a data parallel group.
|
:func:`handle_gradient` among a data parallel group.
|
||||||
For better performance, it bucketizes the gradients of all parameters that are
|
For better performance, it bucketizes the gradients of all parameters that are
|
||||||
the same type to improve the efficiency of communication.
|
the same type to improve the efficiency of communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Model where the gradients accumulate.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, optimizer=None):
|
def __init__(self, model, optimizer=None):
|
||||||
|
@ -29,7 +33,6 @@ class MoeGradientHandler(BaseGradientHandler):
|
||||||
if global_data > 1:
|
if global_data > 1:
|
||||||
epsize_param_dict = get_moe_epsize_param_dict(self._model)
|
epsize_param_dict = get_moe_epsize_param_dict(self._model)
|
||||||
|
|
||||||
|
|
||||||
# epsize is 1, indicating the params are replicated among processes in data parallelism
|
# epsize is 1, indicating the params are replicated among processes in data parallelism
|
||||||
# use the ParallelMode.DATA to get data parallel group
|
# use the ParallelMode.DATA to get data parallel group
|
||||||
# reduce gradients for all parameters in data parallelism
|
# reduce gradients for all parameters in data parallelism
|
||||||
|
|
|
@ -18,6 +18,10 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
|
||||||
:func:`handle_gradient` among all sub pipeline parallel groups.
|
:func:`handle_gradient` among all sub pipeline parallel groups.
|
||||||
For better performance, it bucketizes the gradients of all parameters that are
|
For better performance, it bucketizes the gradients of all parameters that are
|
||||||
the same type to improve the efficiency of communication.
|
the same type to improve the efficiency of communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Model where the gradients accumulate.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self):
|
||||||
|
|
|
@ -12,6 +12,10 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
|
||||||
:func:`handle_gradient` among a data parallel group.
|
:func:`handle_gradient` among a data parallel group.
|
||||||
For better performance, it bucketizes the gradients of all parameters that are
|
For better performance, it bucketizes the gradients of all parameters that are
|
||||||
the same type to improve the efficiency of communication.
|
the same type to improve the efficiency of communication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Model where the gradients accumulate.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self):
|
||||||
|
|
|
@ -8,6 +8,10 @@ class ZeROGradientHandler(BaseGradientHandler):
|
||||||
A all-reduce collective communication will be operated in
|
A all-reduce collective communication will be operated in
|
||||||
:func:`handle_gradient` among a data parallel group.
|
:func:`handle_gradient` among a data parallel group.
|
||||||
This class is specialized with ZeRO optimization.
|
This class is specialized with ZeRO optimization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Model where the gradients accumulate.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def handle_gradient(self):
|
def handle_gradient(self):
|
||||||
|
|
|
@ -28,7 +28,11 @@ class BaseParamHookMgr(object):
|
||||||
handle = p.register_hook(functools.partial(hook_call, p))
|
handle = p.register_hook(functools.partial(hook_call, p))
|
||||||
p._base_param_hook = handle
|
p._base_param_hook = handle
|
||||||
|
|
||||||
def remove_hooks(self):
|
def remove_hooks(self) -> None:
|
||||||
|
"""
|
||||||
|
Remove hooks from model parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
for p in self._param_list:
|
for p in self._param_list:
|
||||||
if p.requires_grad and hasattr(p, '_base_param_hook'):
|
if p.requires_grad and hasattr(p, '_base_param_hook'):
|
||||||
p._base_param_hook.remove()
|
p._base_param_hook.remove()
|
||||||
|
|
|
@ -81,6 +81,9 @@ class PipelineSchedule(BaseSchedule):
|
||||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||||
scatter_gather_tensors: bool = False):
|
scatter_gather_tensors: bool = False):
|
||||||
super().__init__(batch_data_process_func=batch_data_process_func)
|
super().__init__(batch_data_process_func=batch_data_process_func)
|
||||||
|
|
||||||
|
assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
|
||||||
|
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.dtype = torch.float
|
self.dtype = torch.float
|
||||||
self.tensor_shape = tensor_shape
|
self.tensor_shape = tensor_shape
|
||||||
|
@ -150,7 +153,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
else:
|
else:
|
||||||
return model(input_tensor, **batch_data)
|
return model(input_tensor, **batch_data)
|
||||||
|
|
||||||
def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
|
def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
|
||||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||||
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
||||||
Returns output tensor. This is a helper function and can be ignored by users.
|
Returns output tensor. This is a helper function and can be ignored by users.
|
||||||
|
@ -186,7 +189,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
)
|
)
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
|
def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
|
||||||
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
||||||
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
||||||
Returns the gradients with respect to the input tensor (None if first stage).
|
Returns the gradients with respect to the input tensor (None if first stage).
|
||||||
|
@ -267,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
input_tensor = comm.recv_forward(ft_shape,
|
input_tensor = comm.recv_forward(ft_shape,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||||
output_tensor = self.forward_step(engine,
|
output_tensor = self._forward_step(engine,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
return_tensors,
|
return_tensors,
|
||||||
return_output_label=return_output_label,
|
return_output_label=return_output_label,
|
||||||
|
@ -295,7 +298,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
for i in range(num_microbatches_remaining):
|
for i in range(num_microbatches_remaining):
|
||||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||||
|
|
||||||
output_tensor = self.forward_step(engine,
|
output_tensor = self._forward_step(engine,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
return_tensors,
|
return_tensors,
|
||||||
return_output_label=return_output_label,
|
return_output_label=return_output_label,
|
||||||
|
@ -323,7 +326,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
input_tensor = input_tensors.pop(0)
|
input_tensor = input_tensors.pop(0)
|
||||||
output_tensor = output_tensors.pop(0)
|
output_tensor = output_tensors.pop(0)
|
||||||
|
|
||||||
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
||||||
|
|
||||||
if last_iteration:
|
if last_iteration:
|
||||||
input_tensor = None
|
input_tensor = None
|
||||||
|
@ -344,7 +347,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||||
|
|
||||||
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
||||||
|
|
||||||
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||||
|
|
||||||
|
@ -358,8 +361,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
class InterleavedPipelineSchedule(PipelineSchedule):
|
class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_microbatches,
|
num_microbatches: int,
|
||||||
num_model_chunks,
|
num_model_chunks: int,
|
||||||
batch_data_process_func: Callable = None,
|
batch_data_process_func: Callable = None,
|
||||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||||
scatter_gather_tensors: bool = False):
|
scatter_gather_tensors: bool = False):
|
||||||
|
@ -378,6 +381,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
"""
|
"""
|
||||||
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
|
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
|
||||||
'num_microbatches must be an integer multiple of pipeline parallel world size'
|
'num_microbatches must be an integer multiple of pipeline parallel world size'
|
||||||
|
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
|
||||||
|
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
|
||||||
super().__init__(num_microbatches,
|
super().__init__(num_microbatches,
|
||||||
batch_data_process_func=batch_data_process_func,
|
batch_data_process_func=batch_data_process_func,
|
||||||
tensor_shape=tensor_shape,
|
tensor_shape=tensor_shape,
|
||||||
|
@ -409,7 +414,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||||
return self._move_to_device(data), self._move_to_device(label)
|
return self._move_to_device(data), self._move_to_device(label)
|
||||||
|
|
||||||
def forward_step(self,
|
def _forward_step(self,
|
||||||
engine,
|
engine,
|
||||||
model_chunk_id,
|
model_chunk_id,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
|
@ -522,7 +527,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
|
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
|
||||||
return model_chunk_id
|
return model_chunk_id
|
||||||
|
|
||||||
def forward_step_helper(microbatch_id):
|
def _forward_step_helper(microbatch_id):
|
||||||
"""Helper method to run forward step with model split into chunks
|
"""Helper method to run forward step with model split into chunks
|
||||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||||
forward_step())."""
|
forward_step())."""
|
||||||
|
@ -535,7 +540,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
len(output_tensors[model_chunk_id]):
|
len(output_tensors[model_chunk_id]):
|
||||||
input_tensors[model_chunk_id].append(None)
|
input_tensors[model_chunk_id].append(None)
|
||||||
input_tensor = input_tensors[model_chunk_id][-1]
|
input_tensor = input_tensors[model_chunk_id][-1]
|
||||||
output_tensor = self.forward_step(engine,
|
output_tensor = self._forward_step(engine,
|
||||||
model_chunk_id,
|
model_chunk_id,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
return_tensors,
|
return_tensors,
|
||||||
|
@ -550,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
|
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
def backward_step_helper(microbatch_id):
|
def _backward_step_helper(microbatch_id):
|
||||||
"""Helper method to run backward step with model split into chunks
|
"""Helper method to run backward step with model split into chunks
|
||||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||||
backward_step())."""
|
backward_step())."""
|
||||||
|
@ -563,7 +568,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
input_tensor = input_tensors[model_chunk_id].pop(0)
|
input_tensor = input_tensors[model_chunk_id].pop(0)
|
||||||
output_tensor = output_tensors[model_chunk_id].pop(0)
|
output_tensor = output_tensors[model_chunk_id].pop(0)
|
||||||
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
|
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
|
||||||
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
|
||||||
|
|
||||||
return input_tensor_grad
|
return input_tensor_grad
|
||||||
|
|
||||||
|
@ -578,7 +583,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
|
|
||||||
for k in range(num_warmup_microbatches):
|
for k in range(num_warmup_microbatches):
|
||||||
model_chunk_id = get_model_chunk_id(k, forward=True)
|
model_chunk_id = get_model_chunk_id(k, forward=True)
|
||||||
output_tensor = forward_step_helper(k)
|
output_tensor = _forward_step_helper(k)
|
||||||
if not gpc.is_pipeline_last_stage():
|
if not gpc.is_pipeline_last_stage():
|
||||||
output_tensor_shapes[model_chunk_id] = output_tensor.shape
|
output_tensor_shapes[model_chunk_id] = output_tensor.shape
|
||||||
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
|
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
|
||||||
|
@ -633,11 +638,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
for k in range(num_microbatches_remaining):
|
for k in range(num_microbatches_remaining):
|
||||||
# Forward pass.
|
# Forward pass.
|
||||||
forward_k = k + num_warmup_microbatches
|
forward_k = k + num_warmup_microbatches
|
||||||
output_tensor = forward_step_helper(forward_k)
|
output_tensor = _forward_step_helper(forward_k)
|
||||||
|
|
||||||
# Backward pass.
|
# Backward pass.
|
||||||
backward_k = k
|
backward_k = k
|
||||||
input_tensor_grad = backward_step_helper(backward_k)
|
input_tensor_grad = _backward_step_helper(backward_k)
|
||||||
|
|
||||||
# Send output_tensor and input_tensor_grad, receive input_tensor
|
# Send output_tensor and input_tensor_grad, receive input_tensor
|
||||||
# and output_tensor_grad.
|
# and output_tensor_grad.
|
||||||
|
@ -708,7 +713,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
|
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
|
||||||
scatter_gather_tensors=self.scatter_gather_tensors))
|
scatter_gather_tensors=self.scatter_gather_tensors))
|
||||||
for k in range(num_microbatches_remaining, num_microbatches):
|
for k in range(num_microbatches_remaining, num_microbatches):
|
||||||
input_tensor_grad = backward_step_helper(k)
|
input_tensor_grad = _backward_step_helper(k)
|
||||||
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
|
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
|
||||||
recv_next = True
|
recv_next = True
|
||||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||||
|
|
Loading…
Reference in New Issue