mirror of https://github.com/hpcaitech/ColossalAI

Add gradient accumulation, fix lr scheduler

parent 0aa07e600c
commit 8aa21d6bc5
@@ -235,7 +235,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
     return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
 
 
-def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
+def build_lr_scheduler(config, optimizer):
     """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
     constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
 
@@ -255,8 +255,7 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
     config_ = config.copy()
     mod_type = config_.pop('type')
     # warmup epochs will overwrite warmup steps
-    if 'warmup_epochs' in config_:
-        warmup_epochs = config_.pop('warmup_epochs')
-        config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
-    return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
-                                              **config_)
+    # if 'warmup_epochs' in config_:
+    #     warmup_epochs = config_.pop('warmup_epochs')
+    #     config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
+    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
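Since build_lr_scheduler now receives only the config and the optimizer, everything the scheduler needs (total steps, warmup steps, and so on) has to be carried by the config itself and is forwarded as keyword arguments. A minimal, self-contained sketch of this registry-style construction in plain PyTorch; linear_warmup_stand_in below is a hypothetical placeholder for whatever class LR_SCHEDULERS actually resolves 'LinearWarmupLR' to:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import LambdaLR

    def linear_warmup_stand_in(optimizer, total_steps, warmup_steps):
        # hypothetical stand-in: linear warmup, then linear decay to zero
        def lr_lambda(step):
            if step < warmup_steps:
                return (step + 1) / warmup_steps
            return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
        return LambdaLR(optimizer, lr_lambda)

    SCHEDULER_REGISTRY = {'LinearWarmupLR': linear_warmup_stand_in}

    def build_lr_scheduler_sketch(config, optimizer):
        config_ = config.copy()
        mod_type = config_.pop('type')
        # after this change, all remaining kwargs must come from the config
        return SCHEDULER_REGISTRY[mod_type](optimizer, **config_)

    optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = build_lr_scheduler_sketch(
        dict(type='LinearWarmupLR', total_steps=60, warmup_steps=5), optimizer)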
@@ -44,7 +44,9 @@ class Engine:
                  criterion: _Loss = None,
                  optimizer: Optimizer = None,
                  lr_scheduler: Optional[_LRScheduler] = None,
-                 schedule: BaseSchedule = None):
+                 schedule: BaseSchedule = None,
+                 gradient_accumulation: int = 1,
+                 lr_scheduler_step: str = 'epoch'):
         self.train_dataloader = train_dataloader
         self.test_dataloader = test_dataloader
         assert model is not None, "Engine requires a model"
@@ -54,6 +56,11 @@ class Engine:
         self.lr_scheduler = lr_scheduler
         self.schedule = schedule if schedule is not None \
             else NoPipelineSchedule()
+        self.grad_accum_size = gradient_accumulation
+        self.grad_accum_step = 0
+        self.lr_step = 0  # for epoch updating
+        if lr_scheduler_step != 'epoch':
+            self.lr_step = 1
         self._logger = get_global_dist_logger()
 
         # build gradient handler
@@ -89,9 +96,13 @@ class Engine:
             self._gradient_handlers.append(handler)
 
         self.schedule.initialize(self.train_dataloader, self.model,
-                                 self.criterion, self.optimizer,
-                                 self.lr_scheduler)
-        self.forward_only = False
+                                 self.criterion, self.optimizer)
+        self.schedule.grad_accum = self.grad_accum_size
+        # add for robustness
+        if self.optimizer is None:
+            self.forward_only = True
+        else:
+            self.forward_only = False
 
     def handle_gradient(self):
         """Handles all-reduce operations of gradients across different parallel groups.
@@ -116,6 +127,7 @@ class Engine:
         """Returns the neural network model in the engine.
         """
         return self.model
 
     def get_optimizer(self):
         """Returns optimizier in the engine.
         """
@@ -146,7 +158,10 @@ class Engine:
     def get_lr(self):
         """Gets current learning rate.
         """
-        return self.schedule.get_lr()
+        if self.lr_scheduler is not None:
+            return self.lr_scheduler.get_lr()[0]
+        else:
+            return self.optimizer.param_groups[0]['lr']
 
     def step(self, return_loss=True):
         """A running step based on the schedule. Usually, it runs a training or
@@ -156,15 +171,27 @@ class Engine:
         :type return_loss: bool
         :return: (output, lablel, loss)
         """
-        self.schedule.zero_grad(forward_only=self.forward_only)
+        if not self.forward_only and self.grad_accum_step == 0:
+            self.schedule.zero_grad()
 
         output, label, loss = self.schedule.forward_backward_step(
             forward_only=self.forward_only, return_loss=return_loss)
 
         if not self.forward_only:
-            # all reduce gradients
-            self.handle_gradient()
-            self.schedule.step()
+            self.grad_accum_step += 1
+            if self.grad_accum_step == self.grad_accum_size:
+                # all reduce gradients
+                self.handle_gradient()
+                self.schedule.step()
+                if self.lr_scheduler is not None and self.lr_step:
+                    self.lr_scheduler.step()
+                self.grad_accum_step = 0
 
         return output, label, loss
 
+    def complete(self):
+        """Updating after a epoch.
+        """
+        self.schedule.consume_batch()
+        if self.lr_scheduler is not None and self.lr_step == 0:
+            self.lr_scheduler.step()
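This is the core of the gradient accumulation change: gradients are zeroed only at the start of an accumulation window, every call to step() runs one forward/backward pass, and the all-reduce, optimizer step, and (when lr_scheduler_step is not 'epoch') the scheduler step happen only after grad_accum_size micro-steps. The same pattern in plain PyTorch, as an illustrative sketch outside the Engine/Schedule API (model, data, and hyperparameters are placeholders):

    import torch
    from torch import nn
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=0.1)
    scheduler = StepLR(optimizer, step_size=10)
    grad_accum_size, grad_accum_step = 4, 0

    micro_batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(8)]
    for data, label in micro_batches:
        if grad_accum_step == 0:
            optimizer.zero_grad()              # only at the start of a window
        loss = nn.functional.cross_entropy(model(data), label) / grad_accum_size
        loss.backward()                        # gradients sum across micro-batches
        grad_accum_step += 1
        if grad_accum_step == grad_accum_size:
            optimizer.step()                   # one real parameter update per window
            scheduler.step()                   # per-update scheduling variant
            grad_accum_step = 0

Dividing the loss by grad_accum_size keeps the accumulated gradient equal to the average over the whole window, which is exactly what the schedule changes below implement with self.grad_accum.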
@@ -15,6 +15,8 @@ class BaseSchedule(ABC):
     def __init__(self):
         self.initialized = False
         self.logger = get_global_dist_logger()
+        self.grad_accum = 1
+        self.training = False
 
     @property
     @abstractmethod
@@ -27,15 +29,13 @@ class BaseSchedule(ABC):
                    dataloader=None,
                    model=None,
                    criterion=None,
-                   optimizer=None,
-                   lr_scheduler=None):
+                   optimizer=None):
        """Initializes the schedule and set parameters before running.

        :param dataloader: DataLoader in training or evaluation
        :param model: The neural network model
        :param criterion: Criterion for calculating loss
        :param optimizer: Optimizer for updating the parameters
-        :param lr_scheduler: Learning rate scheduler in the process
        """
        self.dataloader = dataloader
        assert model is not None, "Schedule requires a model"
@@ -44,7 +44,6 @@ class BaseSchedule(ABC):
         self.criterion = criterion
         assert optimizer is not None, "Schedule requires an optimizer"
         self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
         self.initialized = True
 
     def check_initialized(self):
@@ -66,6 +65,13 @@ class BaseSchedule(ABC):
         data, label = next(self.data_iter)
         return self._move_to_device(data), self._move_to_device(label)
 
+    def consume_batch(self):
+        while True:
+            try:
+                self.load_batch()
+            except StopIteration:
+                break
+
     def _move_to_device(self, data):
         if isinstance(data, (
                 tuple,
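consume_batch simply drains whatever remains in the data iterator. It appears to be needed because num_steps is truncated to a multiple of grad_accum (see the num_steps changes below), so the trailing batches that cannot fill a complete accumulation window are never fetched by the training loop; draining them lets Engine.complete() finish the epoch cleanly, and complete() also steps the lr scheduler when lr_scheduler_step is left at 'epoch'. A minimal equivalent of the drain loop:

    def consume_remaining(data_iter):
        # exhaust the iterator; equivalent to calling load_batch() until StopIteration
        for _ in data_iter:
            pass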
@@ -87,6 +93,7 @@ class BaseSchedule(ABC):
         :param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
         """
         self.check_initialized()
+        self.training = mode
         if mode:
             self.model.train()
         else:
@@ -102,22 +109,11 @@ class BaseSchedule(ABC):
         self.check_initialized()
         self.optimizer.zero_grad()
 
-    def get_lr(self):
-        """Returns the current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
     def step(self):
         """Updates the parameters and learning rate with the optimizer.
         """
         self.check_initialized()
         self.optimizer.step()
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
 
     @abstractmethod
     def forward_backward_step(self, forward_only=False, return_loss=True):
@@ -81,19 +81,20 @@ class NoPipelineSchedule(BaseSchedule):
 
     @property
     def num_steps(self):
-        return len(self.dataloader)
+        length = len(self.dataloader)
+        if self.training:
+            length -= length % self.grad_accum
+        return length
 
     def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
+                   dataloader=None,
+                   model=None,
+                   criterion=None,
+                   optimizer=None):
         super().initialize(dataloader,
                            model,
                            criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
+                           optimizer)
         if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                        ZeroRedundancyOptimizer_Level_3)):
             self.use_zero_level_2_3 = True
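Worked example: with len(self.dataloader) == 103 and grad_accum == 5, num_steps becomes 103 - 103 % 5 == 100 during training, so an epoch runs exactly 100 forward/backward passes and 20 optimizer updates, and the 3 leftover batches are later drained by consume_batch.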
@@ -147,6 +148,7 @@ class NoPipelineSchedule(BaseSchedule):
                 output = (output,)
             if return_loss:
                 loss = self.criterion(*output, *label)
+                loss /= self.grad_accum
 
         if not forward_only:
             # backward
@@ -168,7 +170,7 @@ class NoPipelineSchedule(BaseSchedule):
                 loss.backward()
 
         if return_loss:
-            return output, label, loss
+            return output, label, loss * self.grad_accum
         else:
             return output, None, None
 
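The two loss edits are complementary: the loss is divided by self.grad_accum before backward so that the accumulated gradient equals the average over the whole window, and multiplied back on return so that the value reported to hooks and logs stays on the usual per-batch scale. With grad_accum == 5 and a raw criterion value of 2.0, backward sees 0.4 per micro-batch while the caller still receives 2.0. The PipelineSchedule hunks further down apply the same normalization, dividing by num_microbatches * grad_accum and rescaling the summed loss on return.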
@@ -179,7 +181,3 @@ class NoPipelineSchedule(BaseSchedule):
             self._torch_amp_scaler.update()
         else:
             self.optimizer.step()
-
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
@@ -119,19 +119,20 @@ class PipelineSchedule(BaseSchedule):
 
     @property
     def num_steps(self):
-        return len(self.dataloader)
+        length = len(self.dataloader)
+        if self.training:
+            length -= length % self.grad_accum
+        return length
 
     def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
+                   dataloader=None,
+                   model=None,
+                   criterion=None,
+                   optimizer=None):
         super().initialize(dataloader,
                            model,
                            criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
+                           optimizer)
         if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                        ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
|
@ -163,7 +164,7 @@ class PipelineSchedule(BaseSchedule):
|
||||||
if return_loss:
|
if return_loss:
|
||||||
input_tensor, label = self.load_micro_batch()
|
input_tensor, label = self.load_micro_batch()
|
||||||
loss_reduced = self.criterion(output_tensor, *
|
loss_reduced = self.criterion(output_tensor, *
|
||||||
label) / self.num_microbatches
|
label) / (self.num_microbatches * self.grad_accum)
|
||||||
return_tensors.append(
|
return_tensors.append(
|
||||||
tuple((output_tensor, label[0], loss_reduced)))
|
tuple((output_tensor, label[0], loss_reduced)))
|
||||||
return loss_reduced
|
return loss_reduced
|
||||||
|
@@ -309,7 +310,7 @@ class PipelineSchedule(BaseSchedule):
             output, label, loss = tuple(map(list, zip(*return_tensors)))
             return (torch.cat(output, dim=0),
                     torch.cat(label, dim=0),
-                    sum(loss))
+                    sum(loss) * self.grad_accum)
         else:
             return tuple((torch.cat(return_tensors, dim=0), None, None))
     else:
@@ -339,16 +339,15 @@ def initialize(config: Union[str, dict] = None,
 
     lr_scheduler = None
     if hasattr(gpc.config, 'lr_scheduler'):
-        if hasattr(gpc.config, 'num_steps'):
-            total_steps = gpc.config.num_steps
-        elif hasattr(gpc.config, 'num_epochs'):
-            total_steps = int(gpc.config.num_epochs * len(train_dataloader))
-        else:
-            raise Exception(
-                'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
-            )
-        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
-                                          total_steps, len(train_dataloader))
+        # if hasattr(gpc.config, 'num_steps'):
+        #     total_steps = gpc.config.num_steps
+        # elif hasattr(gpc.config, 'num_epochs'):
+        #     total_steps = int(gpc.config.num_epochs * len(train_dataloader))
+        # else:
+        #     raise Exception(
+        #         'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
+        #     )
+        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer)
         logger.info('Learning rate scheduler is created', ranks=[0])
 
     # pipeline or no pipeline schedule
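With this block commented out, initialize() no longer derives total_steps from num_steps or num_epochs, so the config has to spell out the step counts itself. A hypothetical way to fill them in when writing a config by hand, assuming the scheduler is stepped once per optimizer update (i.e. lr_scheduler_step != 'epoch'); all numbers are placeholders:

    num_epochs = 60
    batches_per_epoch = 100                      # e.g. len(train_dataloader)
    gradient_accumulation = 5
    updates_per_epoch = batches_per_epoch // gradient_accumulation
    lr_scheduler = dict(
        type='LinearWarmupLR',
        total_steps=num_epochs * updates_per_epoch,
        warmup_steps=5,
    )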
@@ -147,6 +147,7 @@ class Trainer:
             if self.exceed_max_step():
                 # stop when max iter is reached
                 break
+        self._engine.complete()
         self._timer.stop('train-epoch', keep_in_history=True)
         self.call_hooks('after_train_epoch')
         self._timer.reset('train-step')
@@ -8,10 +8,10 @@ BATCH_SIZE = 512
 IMG_SIZE = 32
 PATCH_SIZE = 4
 DIM = 512
-NUM_ATTENTION_HEADS = 8
+NUM_ATTENTION_HEADS = 2
 SUMMA_DIM = 2
 NUM_CLASSES = 10
-DEPTH = 6
+DEPTH = 1
 
 train_data = dict(
     dataset=dict(
@@ -127,14 +127,14 @@ hooks = [
     dict(type='LogMetricByEpochHook'),
     dict(type='Accuracy2DHook'),
     dict(type='LossHook'),
-    dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    # dict(type='TensorboardHook', log_dir='./tfb_logs'),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]
 
 parallel = dict(
     pipeline=dict(size=1),
-    tensor=dict(size=4, mode='2d'),
+    tensor=dict(size=1, mode='2d'),
 )
 
 # for fp16 training
@@ -146,7 +146,8 @@ parallel = dict(
 
 lr_scheduler = dict(
     type='LinearWarmupLR',
-    warmup_epochs=5
+    total_steps=60,
+    warmup_steps=5
 )
 
 # only needed when pipeline parallel is used
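Because build_lr_scheduler no longer converts warmup_epochs into steps, the test config switches to explicit counts; total_steps=60 and warmup_steps=5 now count scheduler steps, and since the test below leaves lr_scheduler_step at its default 'epoch', the scheduler advances once per epoch, so these values effectively play the role the old warmup_epochs did.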
@@ -17,7 +17,8 @@ def run_trainer():
         criterion=criterion,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        schedule=schedule
+        schedule=schedule,
+        gradient_accumulation=5,
     )
     logger.info("engine is built", ranks=[0])
 
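With BATCH_SIZE = 512 in the config and gradient_accumulation=5, each optimizer update now accumulates gradients over 5 consecutive batches, i.e. 5 × 512 = 2560 samples per update from each dataloader, while the learning rate scheduler, with the default lr_scheduler_step='epoch', advances once per epoch via Engine.complete().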