mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] set optimizer to optional in execute_pipeline (#4630)
* set optimizer to optional in execute_pipeline
* arrange device and mixed precision in booster init
* fix execute_pipeline in booster.py
parent c3d5fa3bac
commit 660eed9124
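The headline change: `Booster.execute_pipeline` no longer requires an optimizer, so forward-only passes (e.g. evaluation) can omit it. A minimal before/after sketch of the calling convention, assuming a booster, model, criterion, and batch iterator set up as in the BERT example further down; the names here are illustrative:

```python
# Before this commit: optimizer was a required positional argument.
outputs = booster.execute_pipeline(batch, model, criterion, optimizer,
                                   return_loss=True, return_outputs=True)

# After this commit: optimizer defaults to None and may be omitted when
# gradients are disabled (forward-only execution).
with torch.no_grad():
    outputs = booster.execute_pipeline(batch, model, criterion,
                                       return_loss=True, return_outputs=True)
```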
@@ -49,7 +49,9 @@ class Booster:
     ```

     Args:
-        device (str or torch.device): The device to run the training. Default: 'cuda'.
+        device (str or torch.device): The device to run the training. Default: None.
+            If plugin is not used or plugin doesn't control the device,
+            this argument will be set as training device ('cuda' will be used if argument is None).
         mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
             If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
             'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
@@ -57,7 +59,7 @@ class Booster:
     """

     def __init__(self,
-                 device: str = 'cuda',
+                 device: Optional[str] = None,
                  mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
         if plugin is not None:
@@ -68,13 +70,16 @@ class Booster:
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
-            warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+            if device is not None:
+                warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
+            device = device or 'cuda'
             self.accelerator = Accelerator(device)

         # set precision
         if self.plugin and self.plugin.control_precision():
-            warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            if mixed_precision is not None:
+                warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
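A sketch of the resulting constructor semantics, assuming a plugin whose `control_device()` and `control_precision()` return True (e.g. the HybridParallelPlugin changed below). The warning now fires only when the caller explicitly passes an argument the plugin will override:

```python
from colossalai.booster import Booster

booster = Booster(plugin=plugin)                 # silent: device/mixed_precision default to None
booster = Booster(device='cuda', plugin=plugin)  # warns: the plugin controls the device
booster = Booster()                              # no plugin: device falls back to 'cuda'
```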
@@ -146,7 +151,7 @@ class Booster:
                          data_iter: Iterator,
                          model: nn.Module,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Optimizer,
+                         optimizer: Optional[Optimizer] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         # run pipeline forward backward pass
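With this signature, whether a backward pass runs is inferred from grad mode (see the `forward_only = not torch.is_grad_enabled()` hunks below), so the contract is roughly the following sketch:

```python
# Grad enabled but no optimizer: the schedule asserts, since backward needs one.
outputs = booster.execute_pipeline(batch, model, criterion)      # AssertionError

# Grad disabled: forward-only, optimizer may stay None.
with torch.no_grad():
    outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True)
```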
@@ -443,15 +443,15 @@ class HybridParallelPlugin(PipelinePluginBase):
                          data_iter: Iterator,
                          model: HybridParallelModule,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
-                                          HybridParallelZeroOptimizer],
+                         optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+                                                   HybridParallelZeroOptimizer]] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
         # return loss or outputs if needed
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
         with ctx:
-            outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
+            outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
                                                           return_outputs)
         model.sync_shared_params()
         if isinstance(optimizer, HybridParallelZeroOptimizer):
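Note the argument reordering in the internal `forward_backward_step` call: since `optimizer` now carries a default, it must follow the required parameters in the positional order. A minimal illustration of the constraint, with a hypothetical signature mirroring the change:

```python
# Python forbids a defaulted positional parameter before a required one,
# hence optimizer moved from second position to after criterion:
def forward_backward_step(model, data_iter, criterion, optimizer=None,
                          return_loss=False, return_outputs=False):
    ...
```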
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional

 import torch
@@ -15,7 +15,7 @@ class PipelinePluginBase(Plugin):
                          data_iter: Iterator,
                          model: ModelWrapper,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: OptimizerWrapper,
+                         optimizer: Optional[OptimizerWrapper] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         pass
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Optional

 from torch import Tensor
 from torch.nn import Module
@@ -14,18 +14,18 @@ class PipelineSchedule:

     def forward_backward_step(self,
                               model: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[[Any, Any], Tensor],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Forward and backward step for pipeline training.

         Args:
             model (Module): Model to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

@@ -237,18 +237,18 @@ class InterleavedSchedule(PipelineSchedule):

     def forward_backward_step(self,
                               model_chunk: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[..., Any],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Runs interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
             model_chunk (List[Module]): Model chunks to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

@@ -256,6 +256,8 @@ class InterleavedSchedule(PipelineSchedule):
             dict: A dict with keys: 'loss' and 'outputs'.
         """
         forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)
         num_model_chunks = len(model_chunk)
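Putting the new base-class contract together, a forward-only evaluation step might look like the following sketch; `schedule`, `model`, and `val_batches` are assumed to exist, and the criterion follows the documented `(outputs, inputs) -> loss` convention:

```python
def criterion(outputs, inputs):
    # Documented contract: take model outputs and inputs, return a loss tensor
    # (e.g. HuggingFace-style outputs that carry a .loss attribute).
    return outputs.loss

with torch.no_grad():  # grad disabled -> forward_only is True inside the schedule
    ret = schedule.forward_backward_step(model,
                                         iter(val_batches),
                                         criterion,
                                         return_loss=True)  # optimizer omitted
    loss = ret['loss']
```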
@@ -210,18 +210,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

     def forward_backward_step(self,
                               model: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[..., Any],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
             model (Module): Model to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

@@ -229,6 +229,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             dict: A dict with keys: 'loss' and 'outputs'.
         """
         forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)

@@ -46,7 +46,6 @@ def move_to_cuda(batch):
 @torch.no_grad()
 def evaluate_model(
     model: nn.Module,
-    optimizer,
     criterion,
     test_dataloader: Union[DataLoader, List[DataLoader]],
     num_labels: int,
@@ -71,12 +70,7 @@ def evaluate_model(
         current_rank = dist.get_rank()
         #TODO pass dataloader to execute_pipeline directly
         batch = iter([batch])
-        outputs = booster.execute_pipeline(batch,
-                                           model,
-                                           criterion,
-                                           optimizer,
-                                           return_loss=True,
-                                           return_outputs=True)
+        outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

         if booster.plugin.stage_manager.is_last_stage():
             val_loss = outputs["loss"]
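Dropping `optimizer` here is safe because `evaluate_model` is decorated with `@torch.no_grad()` (first hunk of this file), so the schedules see `torch.is_grad_enabled()` as False and the new assert passes. An equivalent sketch without the decorator:

```python
with torch.no_grad():
    outputs = booster.execute_pipeline(batch, model, criterion,
                                       return_loss=True, return_outputs=True)
```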
@@ -304,7 +298,7 @@ def main():
     for epoch in range(NUM_EPOCHS):
         train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

-    results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
+    results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
                              data_builder.eval_splits, booster, coordinator)

     if coordinator.is_master():
@@ -110,9 +110,9 @@ def examine_pp(num_micro_batches):
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(sharded_model,
-                                            pp_optimizer,
                                             iter(input_list),
                                             criterion,
+                                            pp_optimizer,
                                             return_loss=True,
                                             return_outputs=True)

@@ -90,9 +90,9 @@ def examine_pp():
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(sharded_model,
-                                            pp_optimizer,
                                             iter(input_list),
                                             criterion,
+                                            pp_optimizer,
                                             return_loss=True,
                                             return_outputs=True)

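Since the tests now pass `pp_optimizer` positionally in its new slot, a slightly more defensive variant (a suggestion, not part of this commit) is to pass it by keyword, so a future reordering of the signature cannot silently swap arguments:

```python
pp_ret = schedule.forward_backward_step(sharded_model,
                                        iter(input_list),
                                        criterion,
                                        optimizer=pp_optimizer,
                                        return_loss=True,
                                        return_outputs=True)
```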