[doc] improved docstring and assertion messages for the engine module (#871)

pull/874/head^2
Frank Lee 2022-04-26 10:00:18 +08:00 committed by GitHub
parent 1c34382678
commit 11f54c7b6b
9 changed files with 180 additions and 60 deletions

View File

@ -1,11 +1,9 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from asyncio.log import logger
from typing import List, Iterable
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from torch import Tensor
@ -23,7 +21,7 @@ class Engine:
Args:
model (``torch.nn.Module``): The neural network model.
optimizer (``torch.optim.Optimizer``): Optimizer for updating the parameters.
optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
clip_grad_norm (float, optional): The norm of gradient clipping.
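The Args block above documents the Engine constructor; the following is a minimal, hedged construction sketch. In practice the engine is normally produced by colossalai.initialize, and the import paths and the ColossalaiOptimizer wrapping shown here are assumptions for illustration only.

# Minimal sketch (assumed usage; colossalai.initialize usually builds the Engine for you).
import torch
from colossalai.engine import Engine
from colossalai.nn.optimizer import ColossalaiOptimizer

model = torch.nn.Linear(16, 4)
optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
criterion = torch.nn.CrossEntropyLoss()

engine = Engine(model=model, optimizer=optimizer, criterion=criterion)
engine.train()  # switch the wrapped model to training mode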
@ -57,7 +55,7 @@ class Engine:
def __init__(self,
model: Module,
optimizer: Optimizer,
optimizer: "ColossalaiOptimizer",
criterion: Optional[_Loss] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
@ -84,9 +82,11 @@ class Engine:
self._ophook_list = []
else:
self._ophook_list = ophook_list
# build schedule
if schedule:
assert isinstance(schedule, BaseSchedule), \
f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
self._schedule = schedule
else:
self._schedule = NonPipelineSchedule()
@ -187,7 +187,7 @@ class Engine:
"""
for handler in self._gradient_handlers:
handler.handle_gradient()
def execute_schedule(self, data_iter: Iterable, **kwargs):
"""Run the forward, loss computation, and backward for the model.
Returns a tuple of (output, label, loss).
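Given that contract, one training iteration might look like the illustrative sketch below; engine and train_dataloader are assumed to come from your own setup (e.g. colossalai.initialize).

# Illustrative only: engine and train_dataloader are built elsewhere.
data_iter = iter(train_dataloader)
engine.zero_grad()
output, label, loss = engine.execute_schedule(data_iter)
engine.step()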

View File

@ -1,9 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union
import torch.nn as nn
from torch import Tensor
from typing import Iterable, Any
from typing import Iterable, Any, Tuple
from colossalai.nn.optimizer import ColossalaiOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer
@ -33,24 +34,54 @@ class GradAccumOptimizer(ColossalaiOptimizer):
self.model = model
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
def zero_grad(self, *args, **kwargs):
def zero_grad(self, *args, **kwargs) -> None:
"""
Set all gradients to zero.
Args:
*args: positional arguments for the wrapped optimizer
**kwargs: keyword arguments for the wrapped optimizer
"""
if self.accumulate_step == 0:
self.optim.zero_grad(*args, **kwargs)
def step(self, *args, **kwargs):
def step(self, *args, **kwargs) -> None:
"""
Update the model parameters.
Args:
*args: positional arguments for the wrapped optimizer
**kwargs: keyword arguments for the wrapped optimizer
"""
if self.accumulate_step < self.accumulate_size:
return None
else:
self.accumulate_step = 0
return self.optim.step(*args, **kwargs)
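The zero_grad/step docstrings above describe how updates are gated by the accumulation counter; a hedged training-loop sketch of that behavior follows. The constructor argument order and the surrounding names (optim, model, criterion, train_dataloader) are assumptions for illustration.

# Hedged sketch: with accumulate_size=4, the wrapped optimizer only steps every 4th iteration.
accum_optim = GradAccumOptimizer(optim, accumulate_size=4, model=model)
for data, label in train_dataloader:
    accum_optim.zero_grad()              # only clears gradients at the start of an accumulation window
    loss = criterion(model(data), label)
    accum_optim.backward(loss)           # scales the loss by 1 / accumulate_size before backward
    accum_optim.step()                   # no-op until the window is full, then applies the update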
def clip_grad_norm(self, model: nn.Module, max_norm: float):
def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
"""
Clip gradients by norm.
Args:
model (:class:`torch.nn.Module`): a torch module instance
max_norm (float): the max norm for gradient clipping
"""
if self.accumulate_step < self.accumulate_size:
pass
else:
self.optim.clip_grad_norm(model, max_norm)
def backward(self, loss: Tensor):
def backward(self, loss: Tensor) -> None:
"""Execute backward pass.
Args:
loss (:class:`torch.Tensor`): the loss value.
"""
self.accumulate_step += 1
if self.is_torch_ddp:
@ -62,7 +93,14 @@ class GradAccumOptimizer(ColossalaiOptimizer):
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
"""Execute backward pass given the gradients of the output.
Args:
tensor (:class:`torch.Tensor`): the output tensor.
grad (:class:`torch.Tensor`): the output gradient.
"""
self.accumulate_step += 1
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
@ -84,7 +122,7 @@ class GradAccumDataloader:
(e.g. Dali dataloader), this class will automatically consume (i.e. load and discard) the remaining 2 batches.
Args:
optim (``Iterable``): Your dataloader object for gradient accumulation.
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
@ -96,15 +134,15 @@ class GradAccumDataloader:
def __getattr__(self, __name: str) -> Any:
return getattr(self.dataloader, __name)
def __len__(self):
def __len__(self) -> int:
return self.steps_per_epoch
def __iter__(self):
def __iter__(self) -> Iterable:
self._cur_step = 0
self._dataiter = iter(self.dataloader)
return self
def __next__(self) -> Any:
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
if self._cur_step < self.steps_per_epoch:
self._cur_step += 1
@ -137,13 +175,30 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.accumulate_step = 0
@staticmethod
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int:
"""
Computes the number of effective training iterations. An effective iteration is defined
as the aggregation of <accumulate_size> iterations. For example, if accumulate_size = 4,
then 4 iterations are considered one effective iteration.
Args:
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
return len(dataloader) // accumulate_size
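A short worked example of the computation above (the numbers are illustrative):

# 100 raw batches per epoch with accumulate_size = 4 -> 100 // 4 == 25 effective steps.
steps = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, 4)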
def __getattr__(self, __name: str) -> Any:
return getattr(self.lr_scheduler, __name)
def step(self, *args, **kwargs):
def step(self, *args, **kwargs) -> None:
"""
Update the learning rate.
Args:
*args: positional arguments for the wrapped lr scheduler.
**kwargs: keyword arguments for the wrapped lr scheduler.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
@ -151,19 +206,52 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.accumulate_step = 0
self.lr_scheduler.step(*args, **kwargs)
def get_lr(self):
def get_lr(self) -> Tensor:
"""
Compute the next learning rate.
Returns:
Tensor: the upcoming learning rate.
"""
return self.lr_scheduler.get_lr()
def get_last_lr(self):
def get_last_lr(self) -> Tensor:
"""
Returns the current learning rate.
Returns:
Tensor: the current learning rate.
"""
return self.lr_scheduler.get_last_lr()
def print_lr(self, *args, **kwargs):
def print_lr(self, *args, **kwargs) -> None:
"""
Print the learning rate.
Args:
*args: positional arguments for the wrapped lr scheduler.
**kwargs: keyword arguments for the wrapped lr scheduler.
"""
self.lr_scheduler.print_lr(*args, **kwargs)
def state_dict(self) -> dict:
"""
Returns the states of the lr scheduler as a dictionary.
Returns:
dict: the states of the lr scheduler.
"""
return self.lr_scheduler.state_dict()
def load_state_dict(self, state_dict: dict) -> None:
"""
Load the states of the lr scheduler from a dictionary object.
Args:
state_dict (dict): the states of the lr scheduler.
"""
self.lr_scheduler.load_state_dict(state_dict)
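Since both methods simply delegate to the wrapped scheduler, checkpointing could look like the sketch below; the file name is illustrative.

# Illustrative checkpointing of the wrapped lr scheduler state.
import torch

torch.save(lr_scheduler.state_dict(), 'lr_scheduler.pt')
lr_scheduler.load_state_dict(torch.load('lr_scheduler.pt'))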
@ -188,7 +276,11 @@ class GradAccumGradientHandler:
self.accumulate_size = accumulate_size
self.accumulate_step = 0
def handle_gradient(self):
def handle_gradient(self) -> None:
"""
Handle gradient reduction only in the last gradient accumulation step.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass

View File

@ -12,6 +12,10 @@ class DataParallelGradientHandler(BaseGradientHandler):
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
of the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):
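For context, a hedged sketch of how a gradient handler like this is driven; the Engine normally constructs and calls handlers itself, so direct use as shown here is illustrative only.

# Illustrative only: the Engine usually invokes handle_gradient() after backward.
handler = DataParallelGradientHandler(model, optimizer)
loss.backward()
handler.handle_gradient()   # all-reduce gradients across the data parallel group
optimizer.step()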

View File

@ -14,6 +14,10 @@ class MoeGradientHandler(BaseGradientHandler):
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
of the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer=None):
@ -29,7 +33,6 @@ class MoeGradientHandler(BaseGradientHandler):
if global_data > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism

View File

@ -18,6 +18,10 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
:func:`handle_gradient` among all sub pipeline parallel groups.
For better performance, it bucketizes the gradients of all parameters that are
of the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):

View File

@ -12,6 +12,10 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
of the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):

View File

@ -8,6 +8,10 @@ class ZeROGradientHandler(BaseGradientHandler):
An all-reduce collective communication will be performed in
:func:`handle_gradient` among a data parallel group.
This class is specialized with ZeRO optimization.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def handle_gradient(self):

View File

@ -28,7 +28,11 @@ class BaseParamHookMgr(object):
handle = p.register_hook(functools.partial(hook_call, p))
p._base_param_hook = handle
def remove_hooks(self):
def remove_hooks(self) -> None:
"""
Remove hooks from model parameters.
"""
for p in self._param_list:
if p.requires_grad and hasattr(p, '_base_param_hook'):
p._base_param_hook.remove()

View File

@ -81,6 +81,9 @@ class PipelineSchedule(BaseSchedule):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
super().__init__(batch_data_process_func=batch_data_process_func)
assert num_microbatches > 0, f'expected num_microbatches to be larger than 0, but got {num_microbatches}'
self.num_microbatches = num_microbatches
self.dtype = torch.float
self.tensor_shape = tensor_shape
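The new assertion guards the first constructor argument; a hedged construction sketch using the keyword name from the signature above:

# Assumed usage: num_microbatches must be a positive integer, per the assertion above.
schedule = PipelineSchedule(num_microbatches=4)
# the schedule is then handed to the Engine and driven via execute_schedule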
@ -150,7 +153,7 @@ class PipelineSchedule(BaseSchedule):
else:
return model(input_tensor, **batch_data)
def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -186,7 +189,7 @@ class PipelineSchedule(BaseSchedule):
)
return output_tensor
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@ -267,11 +270,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self.forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
@ -295,11 +298,11 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if forward_only:
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
@ -323,7 +326,7 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
if last_iteration:
input_tensor = None
@ -344,7 +347,7 @@ class PipelineSchedule(BaseSchedule):
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
@ -358,8 +361,8 @@ class PipelineSchedule(BaseSchedule):
class InterleavedPipelineSchedule(PipelineSchedule):
def __init__(self,
num_microbatches,
num_model_chunks,
num_microbatches: int,
num_model_chunks: int,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
@ -378,6 +381,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
super().__init__(num_microbatches,
batch_data_process_func=batch_data_process_func,
tensor_shape=tensor_shape,
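A hedged construction sketch that satisfies both assertions above, assuming a pipeline parallel world size of 2:

# num_microbatches (8) is a multiple of the pipeline world size (2) and
# num_model_chunks (2) is a positive integer, so both assertions pass.
schedule = InterleavedPipelineSchedule(num_microbatches=8, num_model_chunks=2)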
@ -409,13 +414,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return self._move_to_device(data), self._move_to_device(label)
def forward_step(self,
engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=True,
accum_loss=None):
def _forward_step(self,
engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -522,7 +527,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
def forward_step_helper(microbatch_id):
def _forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
@ -535,12 +540,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = self.forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
@ -550,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
return output_tensor
def backward_step_helper(microbatch_id):
def _backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
@ -563,7 +568,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
@ -578,7 +583,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = forward_step_helper(k)
output_tensor = _forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
@ -633,11 +638,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
output_tensor = _forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
input_tensor_grad = _backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
@ -708,7 +713,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
input_tensor_grad = _backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):