|
|
|
@ -49,7 +49,9 @@ class Booster:
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
device (str or torch.device): The device to run the training. Default: 'cuda'. |
|
|
|
|
device (str or torch.device): The device to run the training. Default: None. |
|
|
|
|
If no plugin is used or the plugin doesn't control the device, |
|
|
|
|
this argument will be used as the training device ('cuda' is used if the argument is None). |
|
|
|
|
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None. |
|
|
|
|
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'. |
|
|
|
|
'fp16' uses PyTorch AMP, while 'fp16_apex' uses NVIDIA Apex. |
|
|
|
@ -57,7 +59,7 @@ class Booster:
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
device: str = 'cuda', |
|
|
|
|
device: Optional[str] = None, |
|
|
|
|
mixed_precision: Union[MixedPrecision, str] = None, |
|
|
|
|
plugin: Optional[Plugin] = None) -> None: |
|
|
|
|
if plugin is not None: |
|
|
|
@ -68,13 +70,16 @@ class Booster:
|
|
|
|
|
# set accelerator |
|
|
|
|
if self.plugin and self.plugin.control_device(): |
|
|
|
|
self.accelerator = None |
|
|
|
|
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') |
|
|
|
|
if device is not None: |
|
|
|
|
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') |
|
|
|
|
else: |
|
|
|
|
device = device or 'cuda' |
|
|
|
|
self.accelerator = Accelerator(device) |
|
|
|
|
|
|
|
|
|
# set precision |
|
|
|
|
if self.plugin and self.plugin.control_precision(): |
|
|
|
|
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') |
|
|
|
|
if mixed_precision is not None: |
|
|
|
|
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') |
|
|
|
|
self.mixed_precision = None |
|
|
|
|
elif mixed_precision is None: |
|
|
|
|
self.mixed_precision = None |
|
|
|
@ -146,7 +151,7 @@ class Booster:
|
|
|
|
|
data_iter: Iterator, |
|
|
|
|
model: nn.Module, |
|
|
|
|
criterion: Callable[[Any, Any], torch.Tensor], |
|
|
|
|
optimizer: Optimizer, |
|
|
|
|
optimizer: Optional[Optimizer] = None, |
|
|
|
|
return_loss: bool = True, |
|
|
|
|
return_outputs: bool = False) -> dict: |
|
|
|
|
# run pipeline forward backward pass |
|
|
|
|