Add gradient accumulation, fix lr scheduler

pull/17/head
1SAA 2021-11-08 15:48:27 +08:00
parent 0aa07e600c
commit 8aa21d6bc5
9 changed files with 93 additions and 70 deletions
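For orientation, the gradient-accumulation pattern this commit wires into Engine.step() reduces to the loop below. This is a minimal standalone sketch in plain PyTorch, not the engine code itself; the function name train_epoch, the accum_size parameter and the model/criterion/optimizer arguments are illustrative placeholders.

def train_epoch(model, criterion, optimizer, lr_scheduler, dataloader, accum_size=4):
    model.train()
    optimizer.zero_grad()
    for i, (data, label) in enumerate(dataloader):
        output = model(data)
        loss = criterion(output, label) / accum_size   # average over the accumulation window
        loss.backward()                                # gradients keep accumulating in .grad
        if (i + 1) % accum_size == 0:
            optimizer.step()                           # one real update per accum_size micro-batches
            optimizer.zero_grad()
            if lr_scheduler is not None:
                lr_scheduler.step()                    # a per-step scheduler advances only on real updates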

View File

@@ -235,7 +235,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
     return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)


-def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
+def build_lr_scheduler(config, optimizer):
     """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
     constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
@@ -255,8 +255,7 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
     config_ = config.copy()
     mod_type = config_.pop('type')
     # warmup epochs will overwrite warmup steps
-    if 'warmup_epochs' in config_:
-        warmup_epochs = config_.pop('warmup_epochs')
-        config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
-    return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
-                                              **config_)
+    # if 'warmup_epochs' in config_:
+    #     warmup_epochs = config_.pop('warmup_epochs')
+    #     config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
+    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)

View File

@@ -44,7 +44,9 @@ class Engine:
                  criterion: _Loss = None,
                  optimizer: Optimizer = None,
                  lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None):
+                 schedule: BaseSchedule = None,
+                 gradient_accumulation: int = 1,
+                 lr_scheduler_step: str = 'epoch'):
         self.train_dataloader = train_dataloader
         self.test_dataloader = test_dataloader
         assert model is not None, "Engine requires a model"
@@ -54,6 +56,11 @@ class Engine:
         self.lr_scheduler = lr_scheduler
         self.schedule = schedule if schedule is not None \
             else NoPipelineSchedule()
+        self.grad_accum_size = gradient_accumulation
+        self.grad_accum_step = 0
+        self.lr_step = 0  # for epoch updating
+        if lr_scheduler_step != 'epoch':
+            self.lr_step = 1
         self._logger = get_global_dist_logger()

         # build gradient handler
@@ -89,8 +96,12 @@ class Engine:
             self._gradient_handlers.append(handler)

         self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer,
-                                 self.lr_scheduler)
-        self.forward_only = False
+                                 self.criterion, self.optimizer)
+        self.schedule.grad_accum = self.grad_accum_size
+        # add for robustness
+        if self.optimizer is None:
+            self.forward_only = True
+        else:
+            self.forward_only = False

     def handle_gradient(self):
@@ -116,6 +127,7 @@ class Engine:
         """Returns the neural network model in the engine.
         """
         return self.model
+
     def get_optimizer(self):
         """Returns optimizier in the engine.
         """
@@ -146,7 +158,10 @@ class Engine:
     def get_lr(self):
         """Gets current learning rate.
         """
-        return self.schedule.get_lr()
+        if self.lr_scheduler is not None:
+            return self.lr_scheduler.get_lr()[0]
+        else:
+            return self.optimizer.param_groups[0]['lr']

     def step(self, return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
@@ -156,15 +171,27 @@
         :type return_loss: bool
         :return: (output, lablel, loss)
         """
-        self.schedule.zero_grad(forward_only=self.forward_only)
+        if not self.forward_only and self.grad_accum_step == 0:
+            self.schedule.zero_grad()

         output, label, loss = self.schedule.forward_backward_step(
             forward_only=self.forward_only, return_loss=return_loss)

         if not self.forward_only:
-            # all reduce gradients
-            self.handle_gradient()
-            self.schedule.step()
+            self.grad_accum_step += 1
+            if self.grad_accum_step == self.grad_accum_size:
+                # all reduce gradients
+                self.handle_gradient()
+                self.schedule.step()
+                if self.lr_scheduler is not None and self.lr_step:
+                    self.lr_scheduler.step()
+                self.grad_accum_step = 0

         return output, label, loss
+
+    def complete(self):
+        """Updating after a epoch.
+        """
+        self.schedule.consume_batch()
+        if self.lr_scheduler is not None and self.lr_step == 0:
+            self.lr_scheduler.step()

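The intended call cadence for the new Engine API, pieced together from the hunks above (the loop skeleton is a simplified stand-in for Trainer's epoch loop; num_epochs, the dataloaders and the model/criterion/optimizer objects are placeholders, and mode-setting on the schedule is omitted):

engine = Engine(train_dataloader=train_dataloader,
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
                gradient_accumulation=4,      # optimizer steps once every 4 micro-batches
                lr_scheduler_step='epoch')    # scheduler advances once per epoch, inside complete()

for epoch in range(num_epochs):
    for _ in range(engine.schedule.num_steps):
        output, label, loss = engine.step()   # forward/backward; real update only on window boundaries
    engine.complete()                          # drain leftover batches, then step the epoch-wise scheduler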
View File

@@ -15,6 +15,8 @@ class BaseSchedule(ABC):
     def __init__(self):
         self.initialized = False
         self.logger = get_global_dist_logger()
+        self.grad_accum = 1
+        self.training = False

     @property
     @abstractmethod
@@ -27,15 +29,13 @@ class BaseSchedule(ABC):
                    dataloader=None,
                    model=None,
                    criterion=None,
-                   optimizer=None,
-                   lr_scheduler=None):
+                   optimizer=None):
         """Initializes the schedule and set parameters before running.

         :param dataloader: DataLoader in training or evaluation
         :param model: The neural network model
         :param criterion: Criterion for calculating loss
         :param optimizer: Optimizer for updating the parameters
-        :param lr_scheduler: Learning rate scheduler in the process
         """
         self.dataloader = dataloader
         assert model is not None, "Schedule requires a model"
@@ -44,7 +44,6 @@ class BaseSchedule(ABC):
         self.criterion = criterion
         assert optimizer is not None, "Schedule requires an optimizer"
         self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
         self.initialized = True

     def check_initialized(self):
@@ -66,6 +65,13 @@ class BaseSchedule(ABC):
             data, label = next(self.data_iter)
         return self._move_to_device(data), self._move_to_device(label)

+    def consume_batch(self):
+        while True:
+            try:
+                self.load_batch()
+            except StopIteration:
+                break
+
     def _move_to_device(self, data):
         if isinstance(data, (
                 tuple,
@@ -87,6 +93,7 @@ class BaseSchedule(ABC):
         :param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
         """
         self.check_initialized()
+        self.training = mode
         if mode:
             self.model.train()
         else:
@@ -102,22 +109,11 @@ class BaseSchedule(ABC):
         self.check_initialized()
         self.optimizer.zero_grad()

-    def get_lr(self):
-        """Returns the current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
     def step(self):
         """Updates the parameters and learning rate with the optimizer.
         """
         self.check_initialized()
         self.optimizer.step()
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()

     @abstractmethod
     def forward_backward_step(self, forward_only=False, return_loss=True):

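A note on the new consume_batch(): with accumulation enabled, num_steps in the two schedule diffs below trims the epoch length down to a multiple of grad_accum, so a trailing partial window never triggers an optimizer step; consume_batch() (called from Engine.complete()) then simply drains whatever batches remain so the data iterator is exhausted before the next epoch begins.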
View File

@@ -81,19 +81,20 @@ class NoPipelineSchedule(BaseSchedule):
     @property
     def num_steps(self):
-        return len(self.dataloader)
+        length = len(self.dataloader)
+        if self.training:
+            length -= length % self.grad_accum
+        return length

     def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
+                   dataloader=None,
+                   model=None,
+                   criterion=None,
+                   optimizer=None):
         super().initialize(dataloader,
                            model,
                            criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
+                           optimizer)
         if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                        ZeroRedundancyOptimizer_Level_3)):
             self.use_zero_level_2_3 = True
@@ -147,6 +148,7 @@ class NoPipelineSchedule(BaseSchedule):
             output = (output,)
         if return_loss:
             loss = self.criterion(*output, *label)
+            loss /= self.grad_accum

         if not forward_only:
             # backward
@@ -168,7 +170,7 @@ class NoPipelineSchedule(BaseSchedule):
                loss.backward()

         if return_loss:
-            return output, label, loss
+            return output, label, loss * self.grad_accum
         else:
             return output, None, None
@@ -179,7 +181,3 @@ class NoPipelineSchedule(BaseSchedule):
             self._torch_amp_scaler.update()
         else:
             self.optimizer.step()
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()

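The loss handling above is the standard accumulation trick: the schedule divides the loss by grad_accum before backward() so that the gradients summed over one window equal the gradient of the mean loss, then multiplies the returned value back so logged losses stay on the usual per-batch scale. A toy check of that identity (illustrative only, not part of the commit):

import torch

w = torch.ones(3, requires_grad=True)
micro_batches = [torch.full((3,), float(i)) for i in range(1, 5)]  # 4 micro-batches
N = len(micro_batches)

for x in micro_batches:
    loss = (w * x).sum() / N       # scaled loss, as in forward_backward_step
    loss.backward()                # gradients accumulate in w.grad

expected = sum(micro_batches) / N  # gradient of the mean loss w.r.t. w
assert torch.allclose(w.grad, expected)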
View File

@@ -119,19 +119,20 @@ class PipelineSchedule(BaseSchedule):
     @property
     def num_steps(self):
-        return len(self.dataloader)
+        length = len(self.dataloader)
+        if self.training:
+            length -= length % self.grad_accum
+        return length

     def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
+                   dataloader=None,
+                   model=None,
+                   criterion=None,
+                   optimizer=None):
         super().initialize(dataloader,
                            model,
                            criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
+                           optimizer)
         if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                        ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
@@ -163,7 +164,7 @@ class PipelineSchedule(BaseSchedule):
         if return_loss:
             input_tensor, label = self.load_micro_batch()
             loss_reduced = self.criterion(output_tensor, *
-                                          label) / self.num_microbatches
+                                          label) / (self.num_microbatches * self.grad_accum)
             return_tensors.append(
                 tuple((output_tensor, label[0], loss_reduced)))
             return loss_reduced
@@ -309,7 +310,7 @@ class PipelineSchedule(BaseSchedule):
                 output, label, loss = tuple(map(list, zip(*return_tensors)))
                 return (torch.cat(output, dim=0),
                         torch.cat(label, dim=0),
-                        sum(loss))
+                        sum(loss) * self.grad_accum)
             else:
                 return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:

View File

@@ -339,16 +339,15 @@ def initialize(config: Union[str, dict] = None,
     lr_scheduler = None
     if hasattr(gpc.config, 'lr_scheduler'):
-        if hasattr(gpc.config, 'num_steps'):
-            total_steps = gpc.config.num_steps
-        elif hasattr(gpc.config, 'num_epochs'):
-            total_steps = int(gpc.config.num_epochs * len(train_dataloader))
-        else:
-            raise Exception(
-                'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
-            )
-        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
-                                          total_steps, len(train_dataloader))
+        # if hasattr(gpc.config, 'num_steps'):
+        #     total_steps = gpc.config.num_steps
+        # elif hasattr(gpc.config, 'num_epochs'):
+        #     total_steps = int(gpc.config.num_epochs * len(train_dataloader))
+        # else:
+        #     raise Exception(
+        #         'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
+        #     )
+        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
         logger.info('Learning rate scheduler is created', ranks=[0])

     # pipeline or no pipeline schedule

View File

@@ -147,6 +147,7 @@ class Trainer:
                 if self.exceed_max_step():
                     # stop when max iter is reached
                     break
+            self._engine.complete()
             self._timer.stop('train-epoch', keep_in_history=True)
             self.call_hooks('after_train_epoch')
             self._timer.reset('train-step')

View File

@@ -8,10 +8,10 @@ BATCH_SIZE = 512
 IMG_SIZE = 32
 PATCH_SIZE = 4
 DIM = 512
-NUM_ATTENTION_HEADS = 8
+NUM_ATTENTION_HEADS = 2
 SUMMA_DIM = 2
 NUM_CLASSES = 10
-DEPTH = 6
+DEPTH = 1

 train_data = dict(
     dataset=dict(
@@ -127,14 +127,14 @@ hooks = [
     dict(type='LogMetricByEpochHook'),
     dict(type='Accuracy2DHook'),
     dict(type='LossHook'),
-    dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    # dict(type='TensorboardHook', log_dir='./tfb_logs'),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]

 parallel = dict(
     pipeline=dict(size=1),
-    tensor=dict(size=4, mode='2d'),
+    tensor=dict(size=1, mode='2d'),
 )

 # for fp16 training
@@ -146,7 +146,8 @@ parallel = dict(
 lr_scheduler = dict(
     type='LinearWarmupLR',
-    warmup_epochs=5
+    total_steps=60,
+    warmup_steps=5
 )

 # only needed when pipeline parallel is used

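Since build_lr_scheduler no longer receives total_steps/num_steps_per_epoch and the warmup_epochs conversion is commented out, scheduler configs now have to state step counts explicitly, as the hunk above does. In the general case that presumably looks like the following, where the values and the NUM_EPOCHS/WARMUP_EPOCHS/STEPS_PER_EPOCH names are illustrative:

lr_scheduler = dict(
    type='LinearWarmupLR',
    total_steps=NUM_EPOCHS * STEPS_PER_EPOCH,      # previously derived by the builder
    warmup_steps=WARMUP_EPOCHS * STEPS_PER_EPOCH,  # previously converted from warmup_epochs
)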
View File

@@ -17,7 +17,8 @@ def run_trainer():
         criterion=criterion,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        schedule=schedule
+        schedule=schedule,
+        gradient_accumulation=5,
     )
     logger.info("engine is built", ranks=[0])