[booster] implemented mixed precision class (#3151)

* [booster] implemented mixed precision class

* polish code
pull/3162/head
Frank Lee 2023-03-17 11:00:15 +08:00 committed by GitHub
parent ecd643f1e4
commit ed19290560
13 changed files with 410 additions and 40 deletions


@ -2,4 +2,3 @@ from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin


@ -1,37 +1,95 @@
from contextlib import contextmanager
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
__all__ = ['Booster']
class Booster:
"""
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precision, accelerator, and plugin.
Examples:
>>> colossalai.launch(...)
>>> plugin = GeminiPlugin(stage=3, ...)
>>> booster = Booster(mixed_precision='fp16', plugin=plugin)
>>>
>>> model = GPT2()
>>> optimizer = Adam(model.parameters())
>>> dataloader = DataLoader(dataset)
>>> lr_scheduler = LinearWarmupScheduler()
>>> criterion = GPTLMLoss()
>>>
>>> model, optimizer, criterion, lr_scheduler, dataloader = booster.boost(model, optimizer, criterion, lr_scheduler, dataloader)
>>>
>>> for epoch in range(max_epochs):
>>> for input_ids, attention_mask in dataloader:
>>> outputs = model(input_ids, attention_mask)
>>> loss = criterion(outputs.logits, input_ids)
>>> booster.backward(loss, optimizer)
>>> optimizer.step()
>>> lr_scheduler.step()
>>> optimizer.zero_grad()
Args:
device (str or torch.device): The device to run the training. Default: 'cuda'.
mixed_precision (str or MixedPrecision): The mixed precision strategy to use for training. Default: None.
    If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
    'fp16' uses PyTorch AMP while 'fp16_apex' uses NVIDIA Apex.
plugin (Plugin): The plugin to run the training. Default: None.
"""
def __init__(self,
             device: Union[str, torch.device] = 'cuda',
             mixed_precision: Union[MixedPrecision, str] = None,
             plugin: Optional[Plugin] = None) -> None:
    # validate and set precision
    if mixed_precision is None:
        # no mixed precision requested, keep plain fp32 training
        self.mixed_precision = None
    elif isinstance(mixed_precision, str):
        # the user will take the default arguments for amp training
        self.mixed_precision = mixed_precision_factory(mixed_precision)
    elif isinstance(mixed_precision, MixedPrecision):
        # the user can customize the arguments by passing the precision object
        self.mixed_precision = mixed_precision
    else:
        raise ValueError(
            f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
        )
def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
          dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
Args:
model (nn.Module): The model to be boosted.
optimizer (Optimizer): The optimizer to be boosted.
criterion (Callable): The criterion to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(lsg): Add plugin control logic
# e.g.
# if self.plugin is not None and self.plugin.control_boost:
# ...
# transform model for mixed precision
if self.mixed_precision is not None:
    model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, lr_scheduler, dataloader
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
    # TODO: implement this method with plugin
    optimizer.backward(loss)
def execute_pipeline(self,
data_iter: Iterator,
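
For reference, a minimal sketch (not part of this commit) of the two accepted forms of the mixed_precision argument; the model and optimizer below are placeholders:

# illustrative only
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(32, 32).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)

# option 1: pass a string and take the default AMP arguments
booster = Booster(mixed_precision='fp16')

# option 2: pass a MixedPrecision object to customize the GradScaler arguments
booster = Booster(mixed_precision=FP16TorchMixedPrecision(init_scale=2.**12))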


@ -0,0 +1,3 @@
from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper']


@ -0,0 +1,121 @@
from typing import Union
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
class OptimizerWrapper:
"""
A standard interface for optimizers wrapped by the Booster.
Args:
optim (Optimizer): The optimizer to be wrapped.
"""
def __init__(self, optim: Optimizer):
self.optim = optim
@property
def parameters(self):
params = []
for group in self.param_groups:
params += group['params']
return params
@property
def param_groups(self):
return self.optim.param_groups
@property
def defaults(self):
return self.optim.defaults
def add_param_group(self, *args, **kwargs):
return self.optim.add_param_group(*args, **kwargs)
def step(self, *args, **kwargs):
"""
Performs a single optimization step.
"""
return self.optim.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
"""
Clears the gradients of all optimized `torch.Tensor`.
"""
self.optim.zero_grad(*args, **kwargs)
def backward(self, loss: Tensor, *args, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
def state_dict(self):
"""
Returns the optimizer state.
"""
return self.optim.state_dict()
def load_state_dict(self, *args, **kwargs):
"""
Loads the optimizer state.
"""
self.optim.load_state_dict(*args, **kwargs)
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
"""
Clips gradient of an iterable of parameters at specified min and max values.
Args:
clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range [-clip_value, clip_value].
Note:
In PyTorch 2.0 and above, you can pass foreach=True as a kwarg to clip_grad_value_ to use the
faster implementation. Please refer to the PyTorch documentation for more details.
"""
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
"""
Clips gradient norm of an iterable of parameters.
Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False
Note:
In PyTorch 2.0 and above, you can pass foreach=True as a kwarg to clip_grad_norm_ to use the
faster implementation. Please refer to the PyTorch documentation for more details.
"""
norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
return norm
def scale_loss(self, loss: Tensor):
"""
Scales the loss for mixed precision training.
Note: Only available for optimizers with mixed precision training.
Args:
loss (Tensor): The loss to be scaled.
"""
raise NotImplementedError(
"The method scale_loss is only available for optimizers with mixed precision training")
def unscale_grad(self):
"""
Unscale the gradients for mixed precision training.
Note: Only available for optimizers with mixed precision training.
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")


@ -0,0 +1,33 @@
from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision
__all__ = [
    'MixedPrecision', 'mixed_precision_factory', 'FP16ApexMixedPrecision', 'FP16TorchMixedPrecision',
    'BF16MixedPrecision', 'FP8MixedPrecision'
]
_mixed_precision_mapping = {
'fp16': FP16TorchMixedPrecision,
'fp16_apex': FP16ApexMixedPrecision,
'bf16': BF16MixedPrecision,
'fp8': FP8MixedPrecision
}
def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
"""
Factory method to create mixed precision object
Args:
mixed_precision_type (str): mixed precision type, one of 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
"""
if mixed_precision_type in _mixed_precision_mapping:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
)
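
A quick usage sketch of the factory (illustrative only):

mp = mixed_precision_factory('fp16')    # FP16TorchMixedPrecision with default GradScaler settings
mixed_precision_factory('int8')         # raises ValueError listing the supported keys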


@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision
class BF16MixedPrecision(MixedPrecision):
pass
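
The BF16 strategy is still a stub here; one possible direction (hypothetical, not part of this commit) is an autocast-based module wrapper with torch.bfloat16, which needs no loss scaling:

# hypothetical sketch -- BF16MixedPrecision is a stub in this commit
import torch
import torch.nn as nn

class _BF16Module(nn.Module):
    """Run the wrapped module under bfloat16 autocast (no GradScaler needed)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            return self.module(*args, **kwargs)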


@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision
class FP16ApexMixedPrecision(MixedPrecision):
pass
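
FP16ApexMixedPrecision is likewise a stub; its configure() would presumably delegate to NVIDIA Apex, roughly along these lines (hypothetical sketch, not part of this commit; requires apex to be installed):

# hypothetical sketch -- FP16ApexMixedPrecision is a stub in this commit
from apex import amp

def _configure_with_apex(model, optimizer, opt_level: str = 'O1'):
    # amp.initialize patches the model and optimizer for Apex mixed precision
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    return model, optimizer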


@ -0,0 +1,122 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from ..interface import OptimizerWrapper
from .mixed_precision_base import MixedPrecision
__all__ = ['FP16TorchMixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
class TorchAMPOptimizer(OptimizerWrapper):
"""
Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.
Args:
optim (Optimizer): Optimizer to wrap.
init_scale (float): Initial scale factor. Default: 2**16.
growth_factor (float): Factor by which the scale is multiplied during
    :meth:`torch.cuda.amp.GradScaler.update` if no inf/NaN gradients occur for
    ``growth_interval`` consecutive iterations. Default: 2.0.
backoff_factor (float): Factor by which the scale is multiplied during
    :meth:`torch.cuda.amp.GradScaler.update` if inf/NaN gradients occur in an
    iteration. Default: 0.5.
growth_interval (int): Number of consecutive iterations without inf/NaN gradients
    that must occur for the scale to be multiplied by ``growth_factor``. Default: 2000.
"""
def __init__(self,
optim: Optimizer,
init_scale: float = 2.**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
super().__init__(optim)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
scaled_loss.backward(*args, **kwargs)
def step(self, *args, **kwargs) -> Optional[float]:
    out = self.scaler.step(self.optim, *args, **kwargs)
    # update the scale factor so that GradScaler can grow or back off after this step
    self.scaler.update()
    return out
def scale_loss(self, loss: Tensor) -> Tensor:
return self.scaler.scale(loss)
def unscale_grad(self) -> None:
self.scaler.unscale_(self.optim)
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
def clip_grad_by_norm(self,
                      max_norm: Union[float, int],
                      norm_type: Union[float, int] = 2.0,
                      error_if_nonfinite: bool = False,
                      *args,
                      **kwargs) -> Tensor:
    self.unscale_grad()
    return super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
class TorchAMPModule(nn.Module):
"""
Module wrapper for mixed precision training in FP16 using PyTorch AMP.
Args:
module (nn.Module): Module to wrap.
"""
def __init__(self, module: nn.Module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
with torch.cuda.amp.autocast():
return self.module(*args, **kwargs)
class FP16TorchMixedPrecision(MixedPrecision):
"""
Mixed precision configuration for FP16 training with PyTorch AMP.
Args:
    init_scale (float): Initial scale factor. Default: 2**16.
    growth_factor (float): Factor by which the scale is multiplied during
        :meth:`torch.cuda.amp.GradScaler.update` if no inf/NaN gradients occur for
        ``growth_interval`` consecutive iterations. Default: 2.0.
    backoff_factor (float): Factor by which the scale is multiplied during
        :meth:`torch.cuda.amp.GradScaler.update` if inf/NaN gradients occur in an
        iteration. Default: 0.5.
    growth_interval (int): Number of consecutive iterations without inf/NaN gradients
        that must occur for the scale to be multiplied by ``growth_factor``. Default: 2000.
"""
def __init__(self,
init_scale: float = 2.**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
super().__init__()
self.torch_amp_kwargs = dict(init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if criterion is not None:
criterion = TorchAMPModule(criterion)
return model, optimizer, criterion
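
For context, the wrappers above encapsulate the standard PyTorch AMP recipe; a hand-rolled equivalent of what configure() wires together looks roughly like this (sketch with placeholder model and data):

# sketch of the raw AMP loop that TorchAMPModule/TorchAMPOptimizer encapsulate
import torch
import torch.nn as nn
from torch.optim import Adam

model = nn.Linear(16, 16).cuda()
optimizer = Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(init_scale=2.**16)

x = torch.randn(4, 16, device='cuda')
with torch.cuda.amp.autocast():          # what TorchAMPModule.forward does
    loss = model(x).mean()
scaler.scale(loss).backward()            # TorchAMPOptimizer.backward
scaler.step(optimizer)                   # TorchAMPOptimizer.step
scaler.update()
optimizer.zero_grad()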


@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision
class FP8MixedPrecision(MixedPrecision):
pass


@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from ..interface import OptimizerWrapper
class MixedPrecision(ABC):
"""
An abstract class for mixed precision training.
"""
@abstractmethod
def configure(self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass
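
Concrete strategies subclass MixedPrecision and implement configure(); a minimal hypothetical subclass (reusing the imports already in this file) could simply wrap the optimizer and otherwise keep plain fp32 training:

# hypothetical example, for illustration only
class NoOpMixedPrecision(MixedPrecision):
    """Wraps the optimizer in the base OptimizerWrapper and otherwise leaves everything in fp32."""

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        return model, OptimizerWrapper(optimizer), criterion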


@ -1,25 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer
__all__ = ['Precision']
class Precision:
def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float):
self.precision_type = precision_type
self.grad_clipping_type = grad_clipping_type
self.grad_clipping_value = grad_clipping_value
def setup_model(self, model: nn.Module) -> nn.Module:
# TODO: implement this method
pass
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
# TODO: implement this method
# inject grad clipping and unscale loss
pass
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
pass


@ -6,7 +6,7 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence GPT
# ===============================
BATCH_SIZE = 2
BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
SEQ_LENGTH = 16


@ -0,0 +1,23 @@
import torch
from torch.optim import Adam
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from tests.kit.model_zoo import model_zoo
def test_torch_amp():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
model = model_fn().cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}
mixed_precision = FP16TorchMixedPrecision()
model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
optimizer.backward(loss)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
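
The test needs a CUDA device since every model is moved to the GPU; a direct-execution guard (not part of this commit) also makes it runnable as a plain script:

if __name__ == '__main__':
    test_torch_amp()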