mirror of https://github.com/hpcaitech/ColossalAI
Add gradient accumulation, fix lr scheduler
parent
0aa07e600c
commit
8aa21d6bc5
|
@ -235,7 +235,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
|
|||
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
|
||||
|
||||
|
||||
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
|
||||
def build_lr_scheduler(config, optimizer):
|
||||
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
|
||||
constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
|
||||
|
||||
|
@ -255,8 +255,7 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
|
|||
config_ = config.copy()
|
||||
mod_type = config_.pop('type')
|
||||
# warmup epochs will overwrite warmup steps
|
||||
if 'warmup_epochs' in config_:
|
||||
warmup_epochs = config_.pop('warmup_epochs')
|
||||
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
|
||||
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
|
||||
**config_)
|
||||
# if 'warmup_epochs' in config_:
|
||||
# warmup_epochs = config_.pop('warmup_epochs')
|
||||
# config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
|
||||
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
|
||||
|
|
|
@ -44,7 +44,9 @@ class Engine:
|
|||
criterion: _Loss = None,
|
||||
optimizer: Optimizer = None,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
schedule: BaseSchedule = None):
|
||||
schedule: BaseSchedule = None,
|
||||
gradient_accumulation: int = 1,
|
||||
lr_scheduler_step: str = 'epoch'):
|
||||
self.train_dataloader = train_dataloader
|
||||
self.test_dataloader = test_dataloader
|
||||
assert model is not None, "Engine requires a model"
|
||||
|
@ -54,6 +56,11 @@ class Engine:
|
|||
self.lr_scheduler = lr_scheduler
|
||||
self.schedule = schedule if schedule is not None \
|
||||
else NoPipelineSchedule()
|
||||
self.grad_accum_size = gradient_accumulation
|
||||
self.grad_accum_step = 0
|
||||
self.lr_step = 0 # for epoch updating
|
||||
if lr_scheduler_step != 'epoch':
|
||||
self.lr_step = 1
|
||||
self._logger = get_global_dist_logger()
|
||||
|
||||
# build gradient handler
|
||||
|
@ -89,9 +96,13 @@ class Engine:
|
|||
self._gradient_handlers.append(handler)
|
||||
|
||||
self.schedule.initialize(self.train_dataloader, self.model,
|
||||
self.criterion, self.optimizer,
|
||||
self.lr_scheduler)
|
||||
self.forward_only = False
|
||||
self.criterion, self.optimizer)
|
||||
self.schedule.grad_accum = self.grad_accum_size
|
||||
# add for robustness
|
||||
if self.optimizer is None:
|
||||
self.forward_only = True
|
||||
else:
|
||||
self.forward_only = False
|
||||
|
||||
def handle_gradient(self):
|
||||
"""Handles all-reduce operations of gradients across different parallel groups.
|
||||
|
@ -116,6 +127,7 @@ class Engine:
|
|||
"""Returns the neural network model in the engine.
|
||||
"""
|
||||
return self.model
|
||||
|
||||
def get_optimizer(self):
|
||||
"""Returns the optimizer in the engine.
|
||||
"""
|
||||
|
@ -146,7 +158,10 @@ class Engine:
|
|||
def get_lr(self):
|
||||
"""Gets current learning rate.
|
||||
"""
|
||||
return self.schedule.get_lr()
|
||||
if self.lr_scheduler is not None:
|
||||
return self.lr_scheduler.get_lr()[0]
|
||||
else:
|
||||
return self.optimizer.param_groups[0]['lr']
|
||||
|
||||
def step(self, return_loss=True):
|
||||
"""A running step based on the schedule. Usually, it runs a training or
|
||||
|
@ -156,15 +171,27 @@ class Engine:
|
|||
:type return_loss: bool
|
||||
:return: (output, label, loss)
|
||||
"""
|
||||
self.schedule.zero_grad(forward_only=self.forward_only)
|
||||
if not self.forward_only and self.grad_accum_step == 0:
|
||||
self.schedule.zero_grad()
|
||||
|
||||
output, label, loss = self.schedule.forward_backward_step(
|
||||
forward_only=self.forward_only, return_loss=return_loss)
|
||||
|
||||
if not self.forward_only:
|
||||
# all reduce gradients
|
||||
self.handle_gradient()
|
||||
|
||||
self.schedule.step()
|
||||
self.grad_accum_step += 1
|
||||
if self.grad_accum_step == self.grad_accum_size:
|
||||
# all reduce gradients
|
||||
self.handle_gradient()
|
||||
self.schedule.step()
|
||||
if self.lr_scheduler is not None and self.lr_step:
|
||||
self.lr_scheduler.step()
|
||||
self.grad_accum_step = 0
|
||||
|
||||
return output, label, loss
|
||||
|
||||
def complete(self):
|
||||
"""Updating after an epoch.
|
||||
"""
|
||||
self.schedule.consume_batch()
|
||||
if self.lr_scheduler is not None and self.lr_step == 0:
|
||||
self.lr_scheduler.step()
|
||||
|
|
|
@ -15,6 +15,8 @@ class BaseSchedule(ABC):
|
|||
def __init__(self):
|
||||
self.initialized = False
|
||||
self.logger = get_global_dist_logger()
|
||||
self.grad_accum = 1
|
||||
self.training = False
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -27,15 +29,13 @@ class BaseSchedule(ABC):
|
|||
dataloader=None,
|
||||
model=None,
|
||||
criterion=None,
|
||||
optimizer=None,
|
||||
lr_scheduler=None):
|
||||
optimizer=None):
|
||||
"""Initializes the schedule and set parameters before running.
|
||||
|
||||
:param dataloader: DataLoader in training or evaluation
|
||||
:param model: The neural network model
|
||||
:param criterion: Criterion for calculating loss
|
||||
:param optimizer: Optimizer for updating the parameters
|
||||
:param lr_scheduler: Learning rate scheduler in the process
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
assert model is not None, "Schedule requires a model"
|
||||
|
@ -44,7 +44,6 @@ class BaseSchedule(ABC):
|
|||
self.criterion = criterion
|
||||
assert optimizer is not None, "Schedule requires an optimizer"
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.initialized = True
|
||||
|
||||
def check_initialized(self):
|
||||
|
@ -66,6 +65,13 @@ class BaseSchedule(ABC):
|
|||
data, label = next(self.data_iter)
|
||||
return self._move_to_device(data), self._move_to_device(label)
|
||||
|
||||
def consume_batch(self):
|
||||
while True:
|
||||
try:
|
||||
self.load_batch()
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
def _move_to_device(self, data):
|
||||
if isinstance(data, (
|
||||
tuple,
|
||||
|
@ -87,6 +93,7 @@ class BaseSchedule(ABC):
|
|||
:param mode: If True, the model will be set to training mode. Otherwise, evaluation mode.
|
||||
"""
|
||||
self.check_initialized()
|
||||
self.training = mode
|
||||
if mode:
|
||||
self.model.train()
|
||||
else:
|
||||
|
@ -102,22 +109,11 @@ class BaseSchedule(ABC):
|
|||
self.check_initialized()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def get_lr(self):
|
||||
"""Returns the current learning rate.
|
||||
"""
|
||||
if self.lr_scheduler is not None:
|
||||
return self.lr_scheduler.get_lr()[0]
|
||||
else:
|
||||
return self.optimizer.param_groups[0]['lr']
|
||||
|
||||
def step(self):
|
||||
"""Updates the parameters and learning rate with the optimizer.
|
||||
"""
|
||||
self.check_initialized()
|
||||
self.optimizer.step()
|
||||
# update lr scheduler
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(self, forward_only=False, return_loss=True):
|
||||
|
|
|
@ -81,19 +81,20 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
|
||||
@property
|
||||
def num_steps(self):
|
||||
return len(self.dataloader)
|
||||
length = len(self.dataloader)
|
||||
if self.training:
|
||||
length -= length % self.grad_accum
|
||||
return length
|
||||
|
||||
def initialize(self,
|
||||
dataloader,
|
||||
model,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler=None):
|
||||
dataloader=None,
|
||||
model=None,
|
||||
criterion=None,
|
||||
optimizer=None):
|
||||
super().initialize(dataloader,
|
||||
model,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler=lr_scheduler)
|
||||
optimizer)
|
||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)):
|
||||
self.use_zero_level_2_3 = True
|
||||
|
@ -147,6 +148,7 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
output = (output,)
|
||||
if return_loss:
|
||||
loss = self.criterion(*output, *label)
|
||||
loss /= self.grad_accum
|
||||
|
||||
if not forward_only:
|
||||
# backward
|
||||
|
@ -168,7 +170,7 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
loss.backward()
|
||||
|
||||
if return_loss:
|
||||
return output, label, loss
|
||||
return output, label, loss * self.grad_accum
|
||||
else:
|
||||
return output, None, None
|
||||
|
||||
|
@ -179,7 +181,3 @@ class NoPipelineSchedule(BaseSchedule):
|
|||
self._torch_amp_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
# update lr scheduler
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
|
|
@ -119,19 +119,20 @@ class PipelineSchedule(BaseSchedule):
|
|||
|
||||
@property
|
||||
def num_steps(self):
|
||||
return len(self.dataloader)
|
||||
length = len(self.dataloader)
|
||||
if self.training:
|
||||
length -= length % self.grad_accum
|
||||
return length
|
||||
|
||||
def initialize(self,
|
||||
dataloader,
|
||||
model,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler=None):
|
||||
dataloader=None,
|
||||
model=None,
|
||||
criterion=None,
|
||||
optimizer=None):
|
||||
super().initialize(dataloader,
|
||||
model,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler=lr_scheduler)
|
||||
optimizer)
|
||||
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)):
|
||||
raise TypeError(
|
||||
|
@ -163,7 +164,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
if return_loss:
|
||||
input_tensor, label = self.load_micro_batch()
|
||||
loss_reduced = self.criterion(output_tensor, *
|
||||
label) / self.num_microbatches
|
||||
label) / (self.num_microbatches * self.grad_accum)
|
||||
return_tensors.append(
|
||||
tuple((output_tensor, label[0], loss_reduced)))
|
||||
return loss_reduced
|
||||
|
@ -309,7 +310,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
output, label, loss = tuple(map(list, zip(*return_tensors)))
|
||||
return (torch.cat(output, dim=0),
|
||||
torch.cat(label, dim=0),
|
||||
sum(loss))
|
||||
sum(loss) * self.grad_accum)
|
||||
else:
|
||||
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
||||
else:
|
||||
|
|
|
@ -339,16 +339,15 @@ def initialize(config: Union[str, dict] = None,
|
|||
|
||||
lr_scheduler = None
|
||||
if hasattr(gpc.config, 'lr_scheduler'):
|
||||
if hasattr(gpc.config, 'num_steps'):
|
||||
total_steps = gpc.config.num_steps
|
||||
elif hasattr(gpc.config, 'num_epochs'):
|
||||
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
|
||||
else:
|
||||
raise Exception(
|
||||
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
|
||||
)
|
||||
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
|
||||
total_steps, len(train_dataloader))
|
||||
# if hasattr(gpc.config, 'num_steps'):
|
||||
# total_steps = gpc.config.num_steps
|
||||
# elif hasattr(gpc.config, 'num_epochs'):
|
||||
# total_steps = int(gpc.config.num_epochs * len(train_dataloader))
|
||||
# else:
|
||||
# raise Exception(
|
||||
# 'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
|
||||
# )
|
||||
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
|
||||
logger.info('Learning rate scheduler is created', ranks=[0])
|
||||
|
||||
# pipeline or no pipeline schedule
|
||||
|
|
|
@ -147,6 +147,7 @@ class Trainer:
|
|||
if self.exceed_max_step():
|
||||
# stop when max iter is reached
|
||||
break
|
||||
self._engine.complete()
|
||||
self._timer.stop('train-epoch', keep_in_history=True)
|
||||
self.call_hooks('after_train_epoch')
|
||||
self._timer.reset('train-step')
|
||||
|
|
|
@ -8,10 +8,10 @@ BATCH_SIZE = 512
|
|||
IMG_SIZE = 32
|
||||
PATCH_SIZE = 4
|
||||
DIM = 512
|
||||
NUM_ATTENTION_HEADS = 8
|
||||
NUM_ATTENTION_HEADS = 2
|
||||
SUMMA_DIM = 2
|
||||
NUM_CLASSES = 10
|
||||
DEPTH = 6
|
||||
DEPTH = 1
|
||||
|
||||
train_data = dict(
|
||||
dataset=dict(
|
||||
|
@ -127,14 +127,14 @@ hooks = [
|
|||
dict(type='LogMetricByEpochHook'),
|
||||
dict(type='Accuracy2DHook'),
|
||||
dict(type='LossHook'),
|
||||
dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
||||
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
|
||||
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
|
||||
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
|
||||
]
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=4, mode='2d'),
|
||||
tensor=dict(size=1, mode='2d'),
|
||||
)
|
||||
|
||||
# for fp16 training
|
||||
|
@ -146,7 +146,8 @@ parallel = dict(
|
|||
|
||||
lr_scheduler = dict(
|
||||
type='LinearWarmupLR',
|
||||
warmup_epochs=5
|
||||
total_steps=60,
|
||||
warmup_steps=5
|
||||
)
|
||||
|
||||
# only needed when pipeline parallel is used
|
||||
|
|
|
@ -17,7 +17,8 @@ def run_trainer():
|
|||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
schedule=schedule
|
||||
schedule=schedule,
|
||||
gradient_accumulation=5,
|
||||
)
|
||||
logger.info("engine is built", ranks=[0])
|
||||
|
||||
|
|
Loading…
Reference in New Issue