mirror of https://github.com/hpcaitech/ColossalAI
[checkpoint]support generalized scheduler (#1222)
parent
a98319f023
commit
04537bf83e
|
@ -2,6 +2,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
|
||||
|
||||
class _enable_get_lr_call:
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
|
@ -33,6 +34,16 @@ class DelayerScheduler(_LRScheduler):
|
|||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.delay_epochs:
|
||||
if not self.finished:
|
||||
|
@ -73,6 +84,16 @@ class WarmupScheduler(_LRScheduler):
|
|||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs:
|
||||
if not self.finished:
|
||||
|
@ -118,6 +139,16 @@ class WarmupDelayerScheduler(_LRScheduler):
|
|||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
|
||||
if not self.finished:
|
||||
|
|
|
@ -29,7 +29,6 @@ def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
|
|||
pg = _scan_for_pg_from_args(elem, {})
|
||||
if pg is not None:
|
||||
return pg
|
||||
print(type(elem), elem, isinstance(elem, (list, tuple)))
|
||||
for k, v in kwargs:
|
||||
if isinstance(v, ColoTensor):
|
||||
pg = v.get_process_group()
|
||||
|
|
|
@ -2,10 +2,20 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import collections
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
|
||||
import inspect
|
||||
from colossalai.utils.model.colo_init_context import colo_state_dict
|
||||
|
||||
|
||||
def filter_dict(dict_to_filter, thing_with_kwargs):
|
||||
sig = inspect.signature(thing_with_kwargs)
|
||||
filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
|
||||
filter_dict = {}
|
||||
for filter_key in filter_keys:
|
||||
if filter_key in dict_to_filter:
|
||||
filter_dict[filter_key] = dict_to_filter[filter_key]
|
||||
return filter_dict
|
||||
|
||||
|
||||
def save_checkpoint(dire: str,
|
||||
epoch: int,
|
||||
model: torch.nn.Module,
|
||||
|
@ -25,9 +35,7 @@ def save_checkpoint(dire: str,
|
|||
model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||
lr_scheduler_dict = lr_scheduler.state_dict()
|
||||
lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
|
||||
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler_dict}
|
||||
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
|
||||
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
|
||||
|
||||
|
||||
|
@ -55,8 +63,13 @@ def load_checkpoint(dire,
|
|||
optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
|
||||
optimizer.load_state_dict(optim_state['optimizer'])
|
||||
lr_scheduler_dict = optim_state['lr_scheduler']
|
||||
after_scheduler_dict = lr_scheduler_dict['after_scheduler']
|
||||
lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(optimizer, after_scheduler_dict['T_max'],
|
||||
after_scheduler_dict['eta_min'],
|
||||
after_scheduler_dict['last_epoch'])
|
||||
if 'after_scheduler_type' in lr_scheduler_dict:
|
||||
after_scheduler_type = lr_scheduler_dict.pop('after_scheduler_type')
|
||||
after_scheduler_dict = lr_scheduler_dict.pop('after_scheduler_dict')
|
||||
reload_scheduler = getattr(torch.optim.lr_scheduler, after_scheduler_type)
|
||||
filtered_dict = filter_dict(after_scheduler_dict, reload_scheduler)
|
||||
lr_scheduler_dict['after_scheduler'] = reload_scheduler(
|
||||
optimizer,
|
||||
**filtered_dict,
|
||||
)
|
||||
lr_scheduler.load_state_dict(lr_scheduler_dict)
|
||||
|
|
|
@ -8,6 +8,8 @@ from functools import partial
|
|||
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.optim.lr_scheduler import MultiplicativeLR
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
|
@ -102,10 +104,14 @@ def remove(path):
|
|||
raise ValueError("file {} is not a file or dir.".format(path))
|
||||
|
||||
|
||||
def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
|
||||
def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
|
||||
num_epoch = 5
|
||||
warmup_epoch = 2
|
||||
|
||||
batch = 3
|
||||
feature = 32
|
||||
category = 16
|
||||
|
||||
train_dataloader = DummyDataLoader(batch, category, feature, length=16)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = MLP(feature, category)
|
||||
|
@ -129,14 +135,25 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
|
|||
weight_decay=0)
|
||||
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
|
||||
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
|
||||
lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)
|
||||
if test_scheduler == 'colossalai_cosine_warmup':
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
|
||||
total_steps=num_epoch,
|
||||
warmup_steps=warmup_epoch)
|
||||
|
||||
elif test_scheduler == 'torch_cosine':
|
||||
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
|
||||
|
||||
elif test_scheduler == 'torch_lambda':
|
||||
lr_lambda = lambda epoch: 0.95
|
||||
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
|
||||
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
|
||||
|
||||
init_spec_func(model, pg)
|
||||
init_spec_func(model_ref, pg)
|
||||
|
||||
for epoch in range(0, 20):
|
||||
for epoch in range(0, num_epoch):
|
||||
if epoch <= test_epoch:
|
||||
for i, image_dict in enumerate(train_dataloader):
|
||||
if use_ddp:
|
||||
|
@ -155,7 +172,6 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
|
|||
for ref_p, p in zip(model_ref.parameters(), model.parameters()):
|
||||
ref_p.data.copy_(p)
|
||||
optimizer_ref = copy.deepcopy(optimizer)
|
||||
lr_scheduler_ref = copy.deepcopy(lr_scheduler)
|
||||
|
||||
check_param_equal(model, model_ref)
|
||||
save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
|
||||
|
@ -189,28 +205,34 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
|
|||
check_param_equal(model_ref, model_reload)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, test_epoch):
|
||||
def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)
|
||||
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4])
|
||||
@pytest.mark.parametrize('use_ddp', [True])
|
||||
@pytest.mark.parametrize('test_epoch', [1, 2, 3])
|
||||
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size, use_ddp, test_epoch):
|
||||
def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
|
||||
if not os.path.isdir('./checkpoint'):
|
||||
os.mkdir('./checkpoint')
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
|
||||
run_func = partial(run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
use_ddp=use_ddp,
|
||||
test_epoch=test_epoch,
|
||||
test_scheduler=test_scheduler)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
remove('./checkpoint')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(4, True, 1)
|
||||
test_checkpoint(4, True, 1, 1)
|
||||
|
|
Loading…
Reference in New Issue