mirror of https://github.com/hpcaitech/ColossalAI
[checkpoint] save sharded optimizer states (#1237)
parent
4a76084dc9
commit
20da6e48c8
|
@ -93,20 +93,17 @@ class ProcessGroup:
|
|||
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
|
||||
self._tp_rank_list.append(rank_id)
|
||||
|
||||
self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
self._has_cpu_groups = False
|
||||
self._cpu_dp_process_group = None
|
||||
self._cpu_tp_process_group = None
|
||||
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
def set_cpu_groups(self):
|
||||
if self.has_cpu_groups:
|
||||
return
|
||||
self.logger.info(
|
||||
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||
self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
|
||||
@property
|
||||
def has_cpu_groups(self):
|
||||
|
@ -152,13 +149,15 @@ class ProcessGroup:
|
|||
return len(self._tp_rank_list)
|
||||
|
||||
def dp_process_group(self):
|
||||
return self._dp_process_group
|
||||
# return self._dp_process_group
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
def tp_process_group(self):
|
||||
return self._tp_process_group
|
||||
# return self._tp_process_group
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
|
||||
def cpu_dp_process_group(self):
|
||||
return self._cpu_dp_process_group
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
|
||||
def cpu_tp_process_group(self):
|
||||
return self._cpu_tp_process_group
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
|
|
|
@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
|
|||
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||
"""
|
||||
model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
|
||||
model_state = {'epoch': epoch, 'model': model.state_dict()}
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||
|
||||
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
|
||||
# 1. convert SHARD ColoTensor to REPLICATE
|
||||
# only rank 0 saves the REPLICATE tensors.
|
||||
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
|
||||
|
||||
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
|
||||
|
||||
|
||||
|
|
|
@ -126,6 +126,9 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
|
|||
model_reload = ColoDDP(model_reload, pg)
|
||||
model_ref = ColoDDP(model_ref, pg)
|
||||
|
||||
init_spec_func(model, pg)
|
||||
init_spec_func(model_ref, pg)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
|
||||
|
@ -135,23 +138,21 @@ def run_checkpoint(init_spec_func, use_ddp, test_epoch, test_scheduler, pg):
|
|||
weight_decay=0)
|
||||
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
|
||||
lr_scheduler = None
|
||||
if test_scheduler == 'colossalai_cosine_warmup':
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
|
||||
total_steps=num_epoch,
|
||||
warmup_steps=warmup_epoch)
|
||||
|
||||
elif test_scheduler == 'torch_cosine':
|
||||
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
|
||||
|
||||
elif test_scheduler == 'torch_lambda':
|
||||
lr_lambda = lambda epoch: 0.95
|
||||
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
|
||||
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
|
||||
|
||||
init_spec_func(model, pg)
|
||||
init_spec_func(model_ref, pg)
|
||||
else:
|
||||
raise TypeError(f"{test_scheduler} is invalid")
|
||||
|
||||
for epoch in range(0, num_epoch):
|
||||
if epoch <= test_epoch:
|
||||
|
@ -212,7 +213,11 @@ def run_dist(rank, world_size, port, use_ddp, test_epoch, test_scheduler):
|
|||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, test_scheduler, pg)
|
||||
run_checkpoint(init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
test_epoch=test_epoch,
|
||||
test_scheduler=test_scheduler,
|
||||
pg=pg)
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
|
@ -236,4 +241,4 @@ def test_checkpoint(world_size, use_ddp, test_epoch, test_scheduler):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(4, True, 1, 1)
|
||||
test_checkpoint(4, True, 1, "colossalai_cosine_warmup")
|
||||
|
|
Loading…
Reference in New Issue