[checkpoint] save sharded optimizer states (#1237)

pull/1239/head
Jiarui Fang 2022-07-08 16:33:13 +08:00 committed by GitHub
parent 4a76084dc9
commit 20da6e48c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 19 deletions

View File

@ -93,20 +93,17 @@ class ProcessGroup:
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
self._tp_rank_list.append(rank_id)
self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
self._has_cpu_groups = False
self._cpu_dp_process_group = None
self._cpu_tp_process_group = None
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def set_cpu_groups(self):
if self.has_cpu_groups:
return
self.logger.info(
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
@property
def has_cpu_groups(self):
@ -152,13 +149,15 @@ class ProcessGroup:
return len(self._tp_rank_list)
def dp_process_group(self):
return self._dp_process_group
# return self._dp_process_group
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def tp_process_group(self):
return self._tp_process_group
# return self._tp_process_group
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
def cpu_dp_process_group(self):
return self._cpu_dp_process_group
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
def cpu_tp_process_group(self):
return self._cpu_tp_process_group
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

View File

@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
"""
model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
model_state = {'epoch': epoch, 'model': model.state_dict()}
if dist.get_rank() == 0:
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
# 1. convert SHARD ColoTensor to REPLICATE
# only rank 0 saves the REPLICATE tensors.
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))

View File

@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
model_reload = ColoDDP(model_reload, pg)
model_ref = ColoDDP(model_ref, pg)
init_spec_func(model, pg)
init_spec_func(model_ref, pg)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
weight_decay=0)
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
lr_scheduler = None
if test_scheduler == 'colossalai_cosine_warmup':
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
total_steps=num_epoch,
warmup_steps=warmup_epoch)
elif test_scheduler == 'torch_cosine':
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
elif test_scheduler == 'torch_lambda':
lr_lambda = lambda epoch: 0.95
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
init_spec_func(model, pg)
init_spec_func(model_ref, pg)
else:
raise TypeError(f"{test_scheduler} is invalid")
for epoch in range(0, num_epoch):
if epoch <= test_epoch:
@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
run_checkpoint(init_1d_row_for_linear_weight_spec,
use_ddp,
test_epoch=test_epoch,
test_scheduler=test_scheduler,
pg=pg)
@pytest.mark.skip
@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
if __name__ == '__main__':
test_checkpoint(4, True, 1, 1)
test_checkpoint(4, True, 1, "colossalai_cosine_warmup")