mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
4.9 KiB
121 lines
4.9 KiB
import torch
|
|
import torch.nn as nn
|
|
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
|
|
from colossalai.utils.checkpoint_io.utils import build_checkpoints
|
|
from torch.optim import Adam
|
|
|
|
|
|
class DummyModel(nn.Module):
    """Tiny test module: a single ``nn.Linear(20, 1)`` layer stored as ``fc``.

    Its state dict therefore holds exactly ``'fc.weight'`` and ``'fc.bias'``,
    which the checkpoint tests below rely on.
    """

    def __init__(self) -> None:
        super(DummyModel, self).__init__()
        self.fc = nn.Linear(in_features=20, out_features=1)
|
|
|
|
|
|
def test_global_model():
    """An unsharded, model-only build must reproduce ``model.state_dict()`` in one checkpoint.

    The first argument to ``build_checkpoints`` is presumably a shard-size limit; with 0 the
    assertions below expect exactly one model checkpoint, no optimizer checkpoints and no
    ``dist_meta``.
    """
    model = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, model)

    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None

    expected = model.state_dict()
    saved = model_ckpts[0]
    # Same parameter names, and every tensor saved bit-for-bit.
    assert set(expected.keys()) == set(saved.keys())
    for name, tensor in expected.items():
        assert torch.equal(tensor, saved[name])
|
|
|
|
|
|
def test_global_model_shard():
    """A model-only build with size argument 80 must split the state dict into two disjoint shards.

    NOTE(review): 80 is presumably a byte limit — fc.weight holds 20 values (80 bytes if float32),
    which would force fc.bias into a second shard; confirm against build_checkpoints.
    """
    model = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(80, model)

    assert len(model_ckpts) == 2
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None

    expected = model.state_dict()
    first_keys = set(model_ckpts[0].keys())
    second_keys = set(model_ckpts[1].keys())
    # Together the shards cover every entry...
    assert set(expected.keys()) == first_keys | second_keys
    # ...and no entry is stored twice.
    assert not (first_keys & second_keys)

    for name, tensor in expected.items():
        for shard in model_ckpts:
            if name in shard:
                assert torch.equal(tensor, shard[name])
|
|
|
|
|
|
def test_global_optimizer():
    """Check an unsharded optimizer checkpoint against the live optimizer.

    After one Adam step, ``build_checkpoints`` with size argument 0 must return a single
    optimizer checkpoint. Its per-parameter state and ``param_groups`` must equal
    ``optimizer.state_dict()``. The metadata must map each parameter name to its
    optimizer-state index and flag every state entry except ``'step'`` as paired.
    """
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    # One step so the optimizer has per-parameter state to checkpoint.
    optimizer.step()

    _, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for paired_flags in meta['paired_os'].values():
        for key, is_paired in paired_flags.items():
            if key == 'step':
                assert not is_paired
            else:
                assert is_paired

    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    # Unsharded: the single checkpoint must hold state for every parameter index.
    assert set(orig_state_dict['state'].keys()) == set(state_dict['state'].keys())
    for idx, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][idx]
        # Compare entry-by-name. Zipping the two .values() views would silently skip a
        # missing entry and misalign values if the dicts were ordered differently.
        assert orig_state.keys() == state.keys()
        for name, v1 in orig_state.items():
            v2 = state[name]
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                # Compare the original value with the saved one (not the saved value with itself).
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']
|
|
|
|
|
|
def test_global_optimizer_shard():
    """Check that a sharded optimizer checkpoint reassembles to ``optimizer.state_dict()``.

    With size argument 80, the assertions expect two optimizer shards with disjoint
    per-parameter state. Only the first shard carries ``param_groups``.
    """
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    # One step so the optimizer has per-parameter state to checkpoint.
    optimizer.step()

    _, optimizer_checkpoints, _ = build_checkpoints(80, model, optimizer)
    assert len(optimizer_checkpoints) == 2
    first, second = optimizer_checkpoints
    assert 'param_groups' in first and 'param_groups' not in second

    orig_state_dict = optimizer.state_dict()
    first_keys = set(first['state'].keys())
    second_keys = set(second['state'].keys())
    # The shards together cover every parameter's state, with no overlap.
    assert set(orig_state_dict['state'].keys()) == first_keys | second_keys
    assert not (first_keys & second_keys)

    for idx, orig_state in orig_state_dict['state'].items():
        state = first['state'][idx] if idx in first_keys else second['state'][idx]
        # Compare entry-by-name. Zipping the two .values() views would silently skip a
        # missing entry and misalign values if the dicts were ordered differently.
        assert orig_state.keys() == state.keys()
        for name, v1 in orig_state.items():
            v2 = state[name]
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2

    assert orig_state_dict['param_groups'] == first['param_groups']
|
|
|
|
|
|
def test_dist_model_optimizer():
    """Check ``build_checkpoints`` when an explicit per-parameter ``dist_meta`` is supplied.

    Two layouts are built with size argument 0. For each, ``meta['dist_meta']`` must echo the
    input, and exactly one model and one optimizer checkpoint must be produced. For the first
    layout, the checkpoints must also contain both parameters and both optimizer-state indices.
    """
    model = DummyModel()
    for param in model.parameters():
        param.grad = torch.rand_like(param)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()

    def build_single(layout):
        # Build once and assert the invariants shared by both layouts.
        model_ckpts, optim_ckpts, meta = build_checkpoints(0, model, optimizer, dist_meta=layout)
        assert layout == meta['dist_meta']
        assert len(model_ckpts) == 1
        assert len(optim_ckpts) == 1
        return model_ckpts[0], optim_ckpts[0]

    # NOTE(review): ParamDistMeta's positional args are presumably
    # (dp_rank, dp_world_size, tp_rank, tp_world_size) — confirm against its definition.
    model_ckpt, optim_ckpt = build_single({
        'fc.weight': ParamDistMeta(0, 2, 0, 1),
        'fc.bias': ParamDistMeta(1, 2, 0, 1),
    })
    assert 'fc.weight' in model_ckpt and 'fc.bias' in model_ckpt
    assert 0 in optim_ckpt['state'] and 1 in optim_ckpt['state']

    build_single({
        'fc.weight': ParamDistMeta(1, 2, 0, 1),
        'fc.bias': ParamDistMeta(1, 2, 0, 1),
    })
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run every test in definition order.
    for _test in (
        test_global_model,
        test_global_model_shard,
        test_global_optimizer,
        test_global_optimizer_shard,
        test_dist_model_optimizer,
    ):
        _test()
|