mirror of https://github.com/hpcaitech/ColossalAI
[booster] implemented mixed precision class (#3151)
* [booster] implemented mixed precision class
* polish code

pull/3162/head
parent ecd643f1e4
commit ed19290560
@@ -2,4 +2,3 @@ from .accelerator import Accelerator
from .booster import Booster
from .environment_table import EnvironmentTable
from .plugin import Plugin
-from .precision import Precision
@@ -1,37 +1,95 @@
from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin

__all__ = ['Booster']


class Booster:
    """
    Booster is a high-level API for training neural networks. It provides a unified interface for
    training with different precision, accelerator, and plugin.

    Examples:
        >>> colossalai.launch(...)
        >>> plugin = GeminiPlugin(stage=3, ...)
        >>> booster = Booster(precision='fp16', plugin=plugin)
        >>>
        >>> model = GPT2()
        >>> optimizer = Adam(model.parameters())
        >>> dataloader = Dataloader(Dataset)
        >>> lr_scheduler = LinearWarmupScheduler()
        >>> criterion = GPTLMLoss()
        >>>
        >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
        >>>
        >>> for epoch in range(max_epochs):
        >>>     for input_ids, attention_mask in dataloader:
        >>>         outputs = model(input_ids, attention_mask)
        >>>         loss = criterion(outputs.logits, input_ids)
        >>>         booster.backward(loss, optimizer)
        >>>         optimizer.step()
        >>>         lr_scheduler.step()
        >>>         optimizer.zero_grad()

    Args:
        device (str or torch.device): The device to run the training. Default: 'cuda'.
        mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
            If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
            'fp16' would use PyTorch AMP while 'fp16_apex' would use Nvidia Apex.
        plugin (Plugin): The plugin to run the training. Default: None.
    """

    def __init__(self,
                 device: Union[str, torch.device] = 'cuda',
-                precision: str = 'fp32',
-                grad_clipping_type: str = 'norm',
-                grad_clipping_value: float = 0.0,
                 mixed_precision: Union[MixedPrecision, str] = None,
                 plugin: Optional[Plugin] = None) -> None:
-       # TODO: implement this method
-       pass
        # validate and set precision
        if isinstance(mixed_precision, str):
            # the user will take the default arguments for amp training
            self.mixed_precision = mixed_precision_factory(mixed_precision)
        elif isinstance(mixed_precision, MixedPrecision):
            # the user can customize the arguments by passing the precision object
            self.mixed_precision = mixed_precision
        else:
            raise ValueError(
                f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
            )

-   def boost(
-           self, *args: Union[nn.Module, Optimizer, LRScheduler, DataLoader]
-   ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
-       # TODO: implement this method
-       pass
    def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
              dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
        """
        Boost the model, optimizer, criterion, lr_scheduler, and dataloader.

        Args:
            model (nn.Module): The model to be boosted.
            optimizer (Optimizer): The optimizer to be boosted.
            criterion (Callable): The criterion to be boosted.
            lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
            dataloader (DataLoader): The dataloader to be boosted.
        """
        # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
        # TODO(lsg): Add plugin control logic
        # e.g.
        # if self.plugin is not None and self.plugin.control_boost:
        #     ...
        # transform model for mixed precision
        model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
        return model, optimizer, criterion, lr_scheduler, dataloader

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
-       # TODO: implement this method
-       pass
        # TODO: implement this method with plugin
        optimizer.backward(loss)

    def execute_pipeline(self,
                         data_iter: Iterator,
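As a quick usage sketch of the Booster API above (illustrative only, not part of this commit: the toy nn.Linear model, SGD optimizer and CUDA device are assumptions), 'fp16' resolves to FP16TorchMixedPrecision through mixed_precision_factory, and backward() simply delegates to the wrapped optimizer:

import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster import Booster

# toy model and optimizer, purely for illustration
model = nn.Linear(8, 2).cuda()
optimizer = SGD(model.parameters(), lr=0.1)
criterion = lambda out: out.mean()

booster = Booster(mixed_precision='fp16')    # builds FP16TorchMixedPrecision internally
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, None, None)

loss = criterion(model(torch.randn(4, 8).cuda()))
booster.backward(loss, optimizer)    # -> optimizer.backward(loss), which scales the loss
optimizer.step()
optimizer.zero_grad()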
@@ -0,0 +1,3 @@
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper']
@@ -0,0 +1,121 @@
from typing import Union

import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer


class OptimizerWrapper:
    """
    A standard interface for optimizers wrapped by the Booster.

    Args:
        optim (Optimizer): The optimizer to be wrapped.
    """

    def __init__(self, optim: Optimizer):
        self.optim = optim

    @property
    def parameters(self):
        params = []

        for group in self.param_groups:
            params += group['params']
        return params

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def add_param_group(self, *args, **kwargs):
        return self.optim.add_param_group(*args, **kwargs)

    def step(self, *args, **kwargs):
        """
        Performs a single optimization step.
        """
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        """
        Clears the gradients of all optimized `torch.Tensor`.
        """
        self.optim.zero_grad(*args, **kwargs)

    def backward(self, loss: Tensor, *args, **kwargs):
        """
        Performs a backward pass on the loss.
        """
        loss.backward(*args, **kwargs)

    def state_dict(self):
        """
        Returns the optimizer state.
        """
        return self.optim.state_dict()

    def load_state_dict(self, *args, **kwargs):
        """
        Loads the optimizer state.
        """
        self.optim.load_state_dict(*args, **kwargs)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """
        Clips gradient of an iterable of parameters at specified min and max values.

        Args:
            clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range
                [-clip_value, clip_value].

        Note:
            In PyTorch 2.0 and above, you can pass foreach=True as a keyword argument to clip_grad_value_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> Tensor:
        """
        Clips gradient norm of an iterable of parameters.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
            error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False

        Note:
            In PyTorch 2.0 and above, you can pass foreach=True as a keyword argument to clip_grad_norm_ to use the
            faster implementation. Please refer to the PyTorch documentation for more details.
        """
        norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
        return norm

    def scale_loss(self, loss: Tensor):
        """
        Scales the loss for mixed precision training.

        Note: Only available for optimizers with mixed precision training.

        Args:
            loss (Tensor): The loss to be scaled.
        """
        raise NotImplementedError(
            "The method scale_loss is only available for optimizers with mixed precision training")

    def unscale_grad(self):
        """
        Unscales the gradients for mixed precision training.

        Note: Only available for optimizers with mixed precision training.
        """
        raise NotImplementedError(
            "The method unscale_grad is only available for optimizers with mixed precision training")
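A minimal usage sketch of the base wrapper (not part of the commit; the toy nn.Linear model and SGD optimizer are illustrative, and the import path follows the interface package added above):

import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster.interface import OptimizerWrapper

model = nn.Linear(4, 1)
optim = OptimizerWrapper(SGD(model.parameters(), lr=0.1))

loss = model(torch.randn(2, 4)).sum()
optim.backward(loss)                   # plain loss.backward() in the base class
optim.clip_grad_by_norm(max_norm=1.0)
optim.step()
optim.zero_grad()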
@@ -0,0 +1,33 @@
from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision

__all__ = [
    'MixedPrecision', 'mixed_precision_factory', 'FP16ApexMixedPrecision', 'FP16TorchMixedPrecision',
    'BF16MixedPrecision', 'FP8MixedPrecision'
]

_mixed_precision_mapping = {
    'fp16': FP16TorchMixedPrecision,
    'fp16_apex': FP16ApexMixedPrecision,
    'bf16': BF16MixedPrecision,
    'fp8': FP8MixedPrecision
}


def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
    """
    Factory method to create a mixed precision object.

    Args:
        mixed_precision_type (str): mixed precision type, i.e. 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
    """

    if mixed_precision_type in _mixed_precision_mapping:
        return _mixed_precision_mapping[mixed_precision_type]()
    else:
        raise ValueError(
            f'Mixed precision type {mixed_precision_type} is not supported, supported types include {list(_mixed_precision_mapping.keys())}'
        )
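A brief sketch of how the factory is meant to be used (illustrative; note that in this commit only 'fp16' maps to a fully implemented class, while the apex/bf16/fp8 classes are still stubs):

from colossalai.booster.mixed_precision import mixed_precision_factory

amp = mixed_precision_factory('fp16')    # an FP16TorchMixedPrecision instance
# model, optimizer, criterion = amp.configure(model, optimizer, criterion)

try:
    mixed_precision_factory('int8')      # unknown key
except ValueError as err:
    print(err)                           # lists the supported keys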
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class BF16MixedPrecision(MixedPrecision):
    pass
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class FP16ApexMixedPrecision(MixedPrecision):
    pass
@@ -0,0 +1,122 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer

from ..interface import OptimizerWrapper
from .mixed_precision_base import MixedPrecision

__all__ = ['FP16TorchMixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']


class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
    """

    def __init__(self,
                 optim: Optimizer,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__(optim)
        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
                                                growth_factor=growth_factor,
                                                backoff_factor=backoff_factor,
                                                growth_interval=growth_interval)

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        scaled_loss = self.scale_loss(loss)
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        return self.scaler.step(self.optim, *args, **kwargs)

    def scale_loss(self, loss: Tensor) -> Tensor:
        return self.scaler.scale(loss)

    def unscale_grad(self) -> None:
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> None:
        self.unscale_grad()
        super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)


class TorchAMPModule(nn.Module):
    """
    Module wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        module (nn.Module): Module to wrap.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.cuda.amp.autocast():
            return self.module(*args, **kwargs)


class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
    """

    def __init__(self,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__()
        self.torch_amp_kwargs = dict(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval)

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        model = TorchAMPModule(model)
        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        if criterion is not None:
            criterion = TorchAMPModule(criterion)
        return model, optimizer, criterion
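To show the unscale-then-clip behaviour in context, here is a small sketch mirroring the test added in this commit (the toy nn.Linear model, Adam optimizer and CUDA device are illustrative assumptions):

import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(16, 16).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)

amp = FP16TorchMixedPrecision(init_scale=2.**12)
model, optimizer, criterion = amp.configure(model, optimizer, lambda x: x.mean())

loss = criterion(model(torch.randn(8, 16).cuda()))
optimizer.backward(loss)             # backward on the scaled loss
optimizer.clip_grad_by_norm(1.0)     # unscales the gradients first, then clips
optimizer.step()                     # GradScaler.step skips the update if grads are not finite
optimizer.zero_grad()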
@@ -0,0 +1,5 @@
from .mixed_precision_base import MixedPrecision


class FP8MixedPrecision(MixedPrecision):
    pass
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple

import torch.nn as nn
from torch.optim import Optimizer

from ..interface import OptimizerWrapper


class MixedPrecision(ABC):
    """
    An abstract class for mixed precision training.
    """

    @abstractmethod
    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        # TODO: implement this method
        pass
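For reference, a concrete subclass only has to return the (possibly wrapped) model, optimizer and criterion. A hypothetical bf16 autocast variant might look like the sketch below; this is NOT the commit's BF16MixedPrecision (which is still a stub), just an illustration of the interface, and the class name is invented:

from typing import Callable, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.mixed_precision import MixedPrecision


class NaiveBF16MixedPrecision(MixedPrecision):    # hypothetical name, for illustration only
    """Runs the forward pass under bf16 autocast; bf16 needs no loss scaling."""

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:

        class _AutocastModule(nn.Module):

            def __init__(self, module: nn.Module):
                super().__init__()
                self.module = module

            def forward(self, *args, **kwargs):
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    return self.module(*args, **kwargs)

        model = _AutocastModule(model)
        if criterion is not None:
            criterion = _AutocastModule(criterion)
        return model, OptimizerWrapper(optimizer), criterion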
@@ -1,25 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer


__all__ = ['Precision']


class Precision:

    def __init__(self, precision_type: torch.dtype, grad_clipping_type: str, grad_clipping_value: float):
        self.precision_type = precision_type
        self.grad_clipping_type = grad_clipping_type
        self.grad_clipping_value = grad_clipping_value

    def setup_model(self, model: nn.Module) -> nn.Module:
        # TODO: implement this method
        pass

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        # TODO: implement this method
        # inject grad clipping and unscale loss
        pass

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        pass
@@ -6,7 +6,7 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence GPT
# ===============================
-BATCH_SIZE = 2
+BATCH_SIZE = 1    # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined
SEQ_LENGTH = 16
@@ -0,0 +1,23 @@
import torch
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from tests.kit.model_zoo import model_zoo


def test_torch_amp():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        model = model_fn().cuda()
        optimizer = Adam(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()
        data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data.items()}
        mixed_precision = FP16TorchMixedPrecision()
        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        optimizer.backward(loss)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()