mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.3 KiB
120 lines
3.3 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import copy
|
|
from functools import partial
|
|
|
|
import colossalai
|
|
import pytest
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
import torch.nn as nn
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.utils import checkpoint, free_port
|
|
from colossalai.zero.sharded_model import ShardedModel
|
|
from common import Net, check_grads, check_params, check_params
|
|
|
|
def checkpoint_wrapper(module, enable=True):
    """Optionally route ``module.forward`` through ``checkpoint``.

    Args:
        module: Object exposing a ``forward`` attribute. It is modified in
            place when ``enable`` is true.
        enable: When false, ``module`` is returned untouched.

    Returns:
        The same ``module`` object that was passed in.
    """
    if enable:
        # Bind the original forward as the first argument of ``checkpoint`` so
        # that calling ``module.forward(x)`` becomes ``checkpoint(forward, x)``.
        module.forward = partial(checkpoint, module.forward)
    return module
|
|
|
|
|
|
class Net(nn.Module):
    """Small MLP whose first two linear layers are each applied twice.

    The forward pass runs ``fc1 -> fc2 -> fc1 -> fc2 -> fc3``, so ``fc1`` and
    ``fc2`` are shared between two positions in the layer sequence.

    Args:
        checkpoint: When true, ``fc1`` is passed through
            ``checkpoint_wrapper`` before the layer sequence is built.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        # A plain list is enough here: the three layers are already registered
        # as submodules through the attribute assignments above, and the list
        # only fixes the order (with repeats) in which they are applied.
        self.layers = [
            self.fc1,
            self.fc2,
            self.fc1,
            self.fc2,
            self.fc3,
        ]

    def forward(self, x):
        """Feed ``x`` through every entry of ``self.layers`` in order."""
        for layer in self.layers:
            x = layer(x)
        return x
|
|
|
|
|
|
def run_step(model, optimizer, x, enable_autocast=False):
    """Run one training step: forward, summed loss, backward, optimizer step.

    Args:
        model: Callable module; it is switched to training mode first.
        optimizer: Optimizer whose gradients are zeroed before the forward
            pass and which is stepped after the backward pass.
        x: Input passed to ``model``.
        enable_autocast: Forwarded as ``enabled`` to
            ``torch.cuda.amp.autocast`` for the forward pass and loss.

    Returns:
        None. The model parameters and optimizer state are updated in place.
    """
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    # Cast the loss to float32 before backpropagating, outside the autocast
    # region, so the backward pass starts from a full-precision scalar.
    loss = loss.float()
    loss.backward()
    optimizer.step()
|
|
|
|
|
|
def decode_booleans(intval, bits):
    """Expand the lowest ``bits`` bits of ``intval`` into a list of booleans.

    Args:
        intval: Integer whose bits are inspected.
        bits: Number of bits to decode, starting from the least significant.

    Returns:
        A list of length ``bits`` where element ``i`` is ``True`` exactly when
        bit ``i`` of ``intval`` is set (index 0 is the least significant bit).
    """
    masks = (1 << bit for bit in range(bits))
    return [(intval & mask) == mask for mask in masks]
|
|
|
|
|
|
def check_config(checkpoint=False, fp16=False, offload=False):
    """Train a plain ``Net`` and a ``ShardedModel`` copy side by side and compare.

    Both models start from identical weights (the sharded one wraps a deep
    copy) and are driven with the same random inputs. Gradients and parameters
    are compared after every step via ``check_grads`` / ``check_params``.

    Args:
        checkpoint: Passed to ``Net`` to enable the checkpoint wrapper on fc1.
        fp16: Passed to ``ShardedModel`` as ``mixed_precision`` and used as the
            autocast flag during the first five steps.
        offload: When true, the copy is moved to CPU before wrapping and
            ``offload_config`` is set to ``{'device': 'cpu'}``.
    """
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)

    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)

    # Phase 1: autocast follows the ``fp16`` flag; default-tolerance checks.
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(model, optimizer, x, enable_autocast=fp16)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
        check_grads(model, zero_model)
        check_params(model, zero_model)

    # Phase 2: autocast always disabled; checks are called with ``loose=True``
    # (presumably a relaxed tolerance -- see ``common`` for the exact meaning).
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(model, optimizer, x, enable_autocast=False)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
        check_grads(model, zero_model, loose=True)
        check_params(model, zero_model, loose=True)
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI and test every flag combination.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` forwarded to ``colossalai.launch``.
    """
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    args = ['checkpoint', 'fp16', 'offload']

    def pack_args(i):
        """Map the bits of ``i`` onto the flag names in ``args`` (bit 0 first)."""
        booleans = decode_booleans(i, len(args))
        return dict(zip(args, booleans))

    # Enumerate all 2**len(args) on/off combinations of the flags.
    for j in range(2 ** len(args)):
        kwargs = pack_args(j)
        print(kwargs)
        check_config(**kwargs)
|
|
|
|
|
|
@pytest.mark.dist
def test_zero_level_3():
    """Spawn ``run_dist`` in ``world_size`` worker process(es) (currently one)."""
    world_size = 1
    # ``free_port()`` is evaluated once here in the parent, so every spawned
    # worker receives the same port; ``mp.spawn`` supplies the rank argument.
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
|
|
|
|
|
|
# Allow running this test file directly as a script, without pytest.
if __name__ == '__main__':
    test_zero_level_3()
|