mirror of https://github.com/hpcaitech/ColossalAI

[unit test] Refactored test cases with component func (#339)

* refactored test with component func
* fixed bug

parent de46450461
commit 526a318032
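The refactor below replaces per-test model and data setup with shared "component funcs" looked up from a registry. For orientation, here is a minimal sketch of the consumption pattern pieced together from the hunks in this diff; the registry object, get_callable and the five-tuple it yields appear in the diff itself, while the shapes, flags and the one-step loop are illustrative assumptions only.

from tests.components_to_test.registry import non_distributed_component_funcs

# Names registered in this commit: 'repeated_computed_layers', 'resnet18', 'nested_model'.
get_components_func = non_distributed_component_funcs.get_callable('nested_model')
model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()

model = model_builder(checkpoint=True)    # checkpoint flag as used by the ZeRO tests below
optimizer = optimizer_builder(model)      # assumed to construct a torch optimizer for the model

for data, label in train_dataloader:      # DummyDataLoader yields a fixed number of batches
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    loss.backward()
    optimizer.step()
    break                                 # a single smoke-test step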
@@ -43,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator):

@non_distributed_component_funcs.register(name='nested_model')
def get_training_components():

-    def model_builder(checkpoint):
+    def model_builder(checkpoint=True):
        return NestedNet(checkpoint)

    trainloader = DummyDataLoader()
@@ -3,12 +3,23 @@ from abc import ABC, abstractmethod

class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length
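A short usage sketch of this base class; the subclass name and tensor shapes are illustrative assumptions, not part of the commit.

import torch

# Assuming DummyDataGenerator (defined above) is importable from the test package.
class RandomBatchLoader(DummyDataGenerator):

    def generate(self):
        # A fake CIFAR-like batch; shapes chosen only for illustration.
        data = torch.rand(16, 3, 32, 32)
        label = torch.randint(low=0, high=10, size=(16,))
        return data, label


loader = RandomBatchLoader(length=4)
for step, (data, label) in enumerate(loader):
    print(step, data.shape, label.shape)    # iterates exactly `length` times, then StopIteration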
@@ -1,21 +1,14 @@
-import os
from functools import partial
-from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, free_port, get_dataloader
-from torch.optim import Adam
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
-from torchvision.models import resnet18
+from colossalai.utils import MultiTimer, free_port
+from tests.components_to_test.registry import non_distributed_component_funcs

BATCH_SIZE = 16
IMG_SIZE = 32
@@ -29,40 +22,13 @@ CONFIG = dict(
def run_trainer_no_pipeline(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    # build model
-    model = resnet18(num_classes=10)
-
-    # build dataloaders
-    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                            download=True,
-                            transform=transforms.Compose([
-                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                                transforms.ToTensor(),
-                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                            ]))
-
-    test_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                           train=False,
-                           download=True,
-                           transform=transforms.Compose([
-                               transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                               transforms.ToTensor(),
-                               transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                           ]))
-
-    train_dataloader = get_dataloader(dataset=train_dataset,
-                                      shuffle=True,
-                                      batch_size=BATCH_SIZE,
-                                      pin_memory=True,
-                                      drop_last=True)
-
-    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True)
-
-    # build optimizer
-    optimizer = Adam(model.parameters(), lr=0.001)
-    criterion = nn.CrossEntropyLoss()
-
-    engine, train_dataloader, *args = colossalai.initialize(model=model,
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model = model_builder()
+        optimizer = optimizer_builder(model)
+        engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                             optimizer=optimizer,
+                                                             criterion=criterion,
+                                                             train_dataloader=train_dataloader)
@@ -78,10 +44,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
        trainer.fit(train_dataloader=train_dataloader,
                    test_dataloader=test_dataloader,
                    epochs=NUM_EPOCHS,
-                   max_steps=100,
+                   max_steps=5,
                    display_progress=True,
                    test_interval=5)
    gpc.destroy()
    torch.cuda.empty_cache()
@@ -20,6 +20,7 @@ from torch.nn.utils import clip_grad_norm_


class Enumerator:

    def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
        self.arg_names = arg_names
        self.enums = Enumerator.all_enumerate(arg_values)
@@ -49,11 +50,12 @@ class Enumerator:

def checkpoint_wrapper(module, enable=True):
    if enable:
-        module.forward = partial(checkpoint, module.forward)
+        module.forward = partial(checkpoint, module.forward, False)
    return module


class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
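For reference, the same forward-rebinding trick written against torch.utils.checkpoint as a stand-in for colossalai.utils.checkpoint; the extra False argument added above is specific to the ColossalAI helper and is not modelled in this sketch.

import torch
import torch.nn as nn
from functools import partial
from torch.utils.checkpoint import checkpoint as torch_checkpoint


def checkpoint_wrapper(module, enable=True):
    # Rebind forward so activations are recomputed during backward instead of stored.
    if enable:
        module.forward = partial(torch_checkpoint, module.forward)
    return module


fc = checkpoint_wrapper(nn.Linear(5, 5))
x = torch.rand(2, 5, requires_grad=True)
fc(x).sum().backward()    # forward runs again here to rebuild the activations
print(x.grad.shape)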
@@ -61,13 +63,7 @@ class Net(nn.Module):
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
+        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
@@ -158,12 +154,7 @@ def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):

def run_dist(rank, world_size, port):
    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    args = ['checkpoint', 'fp16', 'offload', 'norm_type']
    arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
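The Enumerator class above sweeps the Cartesian product of this argument grid. A standalone illustration with itertools.product, used here only as a stand-in for Enumerator.all_enumerate, whose exact return format is not shown in this diff:

from itertools import product

arg_names = ['checkpoint', 'fp16', 'offload', 'norm_type']
arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]

# 2 * 2 * 2 * 3 = 24 configurations in total.
for combo in product(*arg_values):
    kwargs = dict(zip(arg_names, combo))
    # e.g. {'checkpoint': False, 'fp16': True, 'offload': False, 'norm_type': 2.0}
    print(kwargs)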
@@ -43,23 +43,6 @@ def checkpoint_wrapper(module, enable=True):
    return module


-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
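The loose branch matters for the half-precision comparisons later in this diff; a small numeric illustration of the difference against the torch.allclose defaults (values chosen only for illustration):

import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.004, 2.004])

print(torch.allclose(a, b))                           # False: default rtol=1e-5, atol=1e-8
print(torch.allclose(a, b, atol=1e-2, rtol=1e-3))     # True: the loose tolerances used above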
@@ -13,7 +13,8 @@ from colossalai.utils import free_port
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, Net, allclose
+from tests.test_zero_data_parallel.common import CONFIG, allclose
+from tests.components_to_test.registry import non_distributed_component_funcs


def _run_shard_tensor(rank, world_size, port):
@@ -68,19 +69,20 @@ def _run_test_shard_param(rank, world_size, port):
    print(param_ref.data)

    logger = get_dist_logger()
-    model = Net()
-
-    # add an attribute as ca_attr to hijack the access to param.data
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, *_ = get_components_func()
+        model = model_builder(checkpoint=True)
+        # add an attribute as col_attr to hijack the access to param.data
        for _, param in model.named_parameters():
            numel_ref = (param.numel() + world_size - 1) // world_size
-            param.ca_attr = ShardedParam(param)
-            param.ca_attr.shard()
-            param_data = param.ca_attr.payload(torch.device('cpu'))
+            param.col_attr = ShardedParam(param)
+            param.col_attr.shard()
+            param_data = param.col_attr.payload(torch.device('cpu'))
            assert (numel_ref == param_data.numel())

        for _, param in model.named_parameters():
-            param.ca_attr.gather()
-            param_data = param.ca_attr.payload(torch.device('cpu'))
+            param.col_attr.gather()
+            param_data = param.col_attr.payload(torch.device('cpu'))

    disable_existing_loggers([logger])
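The numel_ref check relies on ceiling division to size each rank's shard; a quick standalone check of that arithmetic (the helper name is ours, for illustration):

def shard_numel(numel: int, world_size: int) -> int:
    # Each of the world_size shards holds ceil(numel / world_size) elements,
    # with the last shard padded up to the same size.
    return (numel + world_size - 1) // world_size

assert shard_numel(10, 4) == 3    # 4 shards of 3 cover 10 elements (2 padding)
assert shard_numel(12, 4) == 3    # divides evenly, no padding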
@@ -3,19 +3,13 @@ import colossalai
import copy
import pytest
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.utils import free_port
from functools import partial


-def check_equal(a, b):
-    """
-    This function checks if two tensors are equal within tolerance
-    """
-    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+from common import allclose
+from tests.components_to_test.registry import non_distributed_component_funcs


def check_completely_equal(a, b):
@@ -36,18 +30,16 @@ def check_sharded_param_consistency():
    pg: partition gradients and optimizer states

    """
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

-    # create layers
-    oss_linear1 = nn.Linear(128, 256)
-    oss_linear2 = nn.Linear(256, 512)
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()

        # create model
-        oss_model = nn.Sequential(oss_linear1, oss_linear2)
+        oss_model = model_builder(checkpoint=True).cuda().half()
        pg_model = copy.deepcopy(oss_model)

-        oss_model = oss_model.cuda().half()
-        pg_model = pg_model.cuda().half()

        # create optimizer
        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
@@ -59,7 +51,8 @@ def check_sharded_param_consistency():
                                          clip_grad_norm=0.0)

        # create
-        input_data = torch.rand(32, 128).cuda().half()
+        data, label = next(iter(train_dataloader))
+        input_data = data.cuda().half()

        # forward
        oss_output = oss_model(input_data)
@@ -73,12 +66,8 @@ def check_sharded_param_consistency():
        # check grad
        # as this param is small, the backward reduction
        # will not be fired
-        oss_linear1_grad = oss_model[0].weight.grad
-        oss_linear2_grad = oss_model[1].weight.grad
-        pg_linear1_grad = pg_model[0].weight.grad
-        pg_linear2_grad = pg_model[1].weight.grad
-        check_completely_equal(oss_linear1_grad, pg_linear1_grad)
-        check_completely_equal(oss_linear2_grad, pg_linear2_grad)
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param.grad, pg_param.grad)

        # step
        oss_optimizer.sync_grad()
@@ -89,8 +78,8 @@ def check_sharded_param_consistency():
        pg_optimizer.step()

        # check updated param
-        check_completely_equal(oss_model[0].weight, pg_model[0].weight)
-        check_completely_equal(oss_model[1].weight, pg_model[1].weight)
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param, pg_param)


def check_sharded_optim_against_torch_ddp():
@@ -103,15 +92,17 @@ def check_sharded_optim_against_torch_ddp():
    differences in model output and updated parameters are within tolerance.
    """

-    # create layer
-    zero_linear1 = nn.Linear(128, 256)
-    zero_linear2 = nn.Linear(256, 512)
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()

        # create model
-        zero_model = nn.Sequential(zero_linear1, zero_linear2)
+        zero_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(zero_model)

-        zero_model = zero_model.cuda().half()
+        zero_model = zero_model.half()
        torch_model = DDP(torch_model.cuda())

        # create optimizer
@@ -120,19 +111,22 @@ def check_sharded_optim_against_torch_ddp():
        # we only test stage 1 here
        # in `check_sharded_param_consistency.py`, we will test whether
        # level 1 and 2 will produce exactly the same results
-        zero_optimizer = ShardedOptimizer(zero_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
+        zero_optimizer = ShardedOptimizer(zero_optimizer,
+                                          overlap_communication=True,
+                                          initial_scale=1,
+                                          clip_grad_norm=0.0)
        torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

        # create
-        input_data = torch.rand(32, 128).cuda()
+        input_data, _ = next(iter(train_dataloader))
+        input_data = input_data.cuda()

        # zero-dp forward
        zero_output = zero_model(input_data.half())

        # torch-ddp forward
        torch_output = torch_model(input_data)
-        check_equal(zero_output, torch_output)
+        allclose(zero_output, torch_output.half())

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float())
@@ -141,12 +135,8 @@ def check_sharded_optim_against_torch_ddp():
        torch_output.mean().backward()

        # check grad
-        zero_linear1_grad = zero_model[0].weight.grad
-        zero_linear2_grad = zero_model[1].weight.grad
-        torch_linear1_grad = torch_model.module[0].weight.grad
-        torch_linear2_grad = torch_model.module[1].weight.grad
-        check_equal(zero_linear1_grad, torch_linear1_grad)
-        check_equal(zero_linear2_grad, torch_linear2_grad)
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param.grad, torch_param.grad.half())

        # zero-dp step
        zero_optimizer.sync_grad()
@@ -156,8 +146,8 @@ def check_sharded_optim_against_torch_ddp():
        torch_optimizer.step()

        # check updated param
-        check_equal(zero_model[0].weight, torch_model.module[0].weight)
-        check_equal(zero_model[1].weight, torch_model.module[1].weight)
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param, torch_param.half())


def run_dist(rank, world_size, port):
@@ -15,6 +15,8 @@ import torch.distributed as dist


def run_dist(rank, world_size, port):
    # this test only runs on resnet18
    # as this model has sync batch normalization
    # need to configure cudnn deterministic so that
    # randomness of convolution layers will be disabled
    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
@@ -22,8 +22,8 @@ def run_dist(rank, world_size, port):
    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model()
+        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model_builder()
        shard_strategy = TensorShardStrategy()
        model = model.half().cuda()
        zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
@@ -1,119 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-
-from common import Net, check_grads, check_params, check_params
-
-
-def checkpoint_wrapper(module, enable=True):
-    if enable:
-        module.forward = partial(checkpoint, module.forward)
-    return module
-
-
-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads(model, zero_model)
-        check_params(model, zero_model)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads(model, zero_model, loose=True)
-        check_params(model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()
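The deleted tests sweep every boolean configuration by decoding an integer bit by bit; a quick standalone check of that helper's logic (example values chosen for illustration):

def decode_booleans(intval, bits):
    res = []
    for bit in range(bits):
        mask = 1 << bit
        res.append((intval & mask) == mask)
    return res

# 5 == 0b101, so with 3 bits we get [True, False, True];
# sweeping range(2 ** 3) covers all 8 on/off combinations of
# ['checkpoint', 'fp16', 'offload'].
assert decode_booleans(5, 3) == [True, False, True]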
@@ -1,97 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import Net, check_grads_padding, check_params_padding
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-    ddp_model = DDP(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads_padding(ddp_model, zero_model)
-        check_params_padding(ddp_model, zero_model)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads_padding(ddp_model, zero_model, loose=True)
-        check_params_padding(ddp_model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        if dist.get_rank() == 0:
-            print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()