Add gradient accumulation, fix lr scheduler

pull/17/head
1SAA 2021-11-08 15:48:27 +08:00
parent 0aa07e600c
commit 8aa21d6bc5
9 changed files with 93 additions and 70 deletions

View File

@ -235,7 +235,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
def build_lr_scheduler(config, optimizer):
"""Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
constructed from `config` and `optimizer`.
@ -255,8 +255,7 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
config_ = config.copy()
mod_type = config_.pop('type')
# warmup epochs will overwrite warmup steps
if 'warmup_epochs' in config_:
warmup_epochs = config_.pop('warmup_epochs')
config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
**config_)
# if 'warmup_epochs' in config_:
# warmup_epochs = config_.pop('warmup_epochs')
# config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)

View File

@ -44,7 +44,9 @@ class Engine:
criterion: _Loss = None,
optimizer: Optimizer = None,
lr_scheduler: Optional[_LRScheduler] = None,
schedule: BaseSchedule = None):
schedule: BaseSchedule = None,
gradient_accumulation: int = 1,
lr_scheduler_step: str = 'epoch'):
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
assert model is not None, "Engine requires a model"
@ -54,6 +56,11 @@ class Engine:
self.lr_scheduler = lr_scheduler
self.schedule = schedule if schedule is not None \
else NoPipelineSchedule()
self.grad_accum_size = gradient_accumulation
self.grad_accum_step = 0
self.lr_step = 0 # for epoch updating
if lr_scheduler_step != 'epoch':
self.lr_step = 1
self._logger = get_global_dist_logger()
# build gradient handler
@ -89,9 +96,13 @@ class Engine:
self._gradient_handlers.append(handler)
self.schedule.initialize(self.train_dataloader, self.model,
self.criterion, self.optimizer,
self.lr_scheduler)
self.forward_only = False
self.criterion, self.optimizer)
self.schedule.grad_accum = self.grad_accum_size
# add for robustness
if self.optimizer is None:
self.forward_only = True
else:
self.forward_only = False
def handle_gradient(self):
"""Handles all-reduce operations of gradients across different parallel groups.
@ -116,6 +127,7 @@ class Engine:
"""Returns the neural network model in the engine.
"""
return self.model
def get_optimizer(self):
"""Returns the optimizer in the engine.
"""
@ -146,7 +158,10 @@ class Engine:
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self, return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or
@ -156,15 +171,27 @@ class Engine:
:type return_loss: bool
:return: (output, label, loss)
"""
self.schedule.zero_grad(forward_only=self.forward_only)
if not self.forward_only and self.grad_accum_step == 0:
self.schedule.zero_grad()
output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss)
if not self.forward_only:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
self.grad_accum_step += 1
if self.grad_accum_step == self.grad_accum_size:
# all reduce gradients
self.handle_gradient()
self.schedule.step()
if self.lr_scheduler is not None and self.lr_step:
self.lr_scheduler.step()
self.grad_accum_step = 0
return output, label, loss
def complete(self):
"""Updates state after an epoch.
"""
self.schedule.consume_batch()
if self.lr_scheduler is not None and self.lr_step == 0:
self.lr_scheduler.step()

View File

@ -15,6 +15,8 @@ class BaseSchedule(ABC):
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
self.grad_accum = 1
self.training = False
@property
@abstractmethod
@ -27,15 +29,13 @@ class BaseSchedule(ABC):
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
optimizer=None):
"""Initializes the schedule and sets parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
@ -44,7 +44,6 @@ class BaseSchedule(ABC):
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
@ -66,6 +65,13 @@ class BaseSchedule(ABC):
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
def consume_batch(self):
while True:
try:
self.load_batch()
except StopIteration:
break
def _move_to_device(self, data):
if isinstance(data, (
tuple,
@ -87,6 +93,7 @@ class BaseSchedule(ABC):
:param mode: If True, the model will be set to training mode. Otherwise, evaluation mode.
"""
self.check_initialized()
self.training = mode
if mode:
self.model.train()
else:
@ -102,22 +109,11 @@ class BaseSchedule(ABC):
self.check_initialized()
self.optimizer.zero_grad()
def get_lr(self):
"""Returns the current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):

View File

@ -81,19 +81,20 @@ class NoPipelineSchedule(BaseSchedule):
@property
def num_steps(self):
return len(self.dataloader)
length = len(self.dataloader)
if self.training:
length -= length % self.grad_accum
return length
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
dataloader=None,
model=None,
criterion=None,
optimizer=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
optimizer)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
@ -147,6 +148,7 @@ class NoPipelineSchedule(BaseSchedule):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
loss /= self.grad_accum
if not forward_only:
# backward
@ -168,7 +170,7 @@ class NoPipelineSchedule(BaseSchedule):
loss.backward()
if return_loss:
return output, label, loss
return output, label, loss * self.grad_accum
else:
return output, None, None
@ -179,7 +181,3 @@ class NoPipelineSchedule(BaseSchedule):
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()

View File

@ -119,19 +119,20 @@ class PipelineSchedule(BaseSchedule):
@property
def num_steps(self):
return len(self.dataloader)
length = len(self.dataloader)
if self.training:
length -= length % self.grad_accum
return length
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
dataloader=None,
model=None,
criterion=None,
optimizer=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
optimizer)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
@ -163,7 +164,7 @@ class PipelineSchedule(BaseSchedule):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
label) / (self.num_microbatches * self.grad_accum)
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
@ -309,7 +310,7 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
sum(loss) * self.grad_accum)
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:

View File

@ -339,16 +339,15 @@ def initialize(config: Union[str, dict] = None,
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
# if hasattr(gpc.config, 'num_steps'):
# total_steps = gpc.config.num_steps
# elif hasattr(gpc.config, 'num_epochs'):
# total_steps = int(gpc.config.num_epochs * len(train_dataloader))
# else:
# raise Exception(
# 'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
# )
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule

View File

@ -147,6 +147,7 @@ class Trainer:
if self.exceed_max_step():
# stop when max iter is reached
break
self._engine.complete()
self._timer.stop('train-epoch', keep_in_history=True)
self.call_hooks('after_train_epoch')
self._timer.reset('train-step')

View File

@ -8,10 +8,10 @@ BATCH_SIZE = 512
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
NUM_ATTENTION_HEADS = 2
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
DEPTH = 1
train_data = dict(
dataset=dict(
@ -127,14 +127,14 @@ hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='Accuracy2DHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
tensor=dict(size=1, mode='2d'),
)
# for fp16 training
@ -146,7 +146,8 @@ parallel = dict(
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
total_steps=60,
warmup_steps=5
)
# only needed when pipeline parallel is used

View File

@ -17,7 +17,8 @@ def run_trainer():
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
schedule=schedule,
gradient_accumulation=5,
)
logger.info("engine is built", ranks=[0])