[zero] test gradient accumulation (#1964)

* [zero] fix memory leak for zero2

* [zero] test gradient accumulation

* [zero] remove grad clip test
HELSON 2022-11-29 13:00:30 +08:00 committed by GitHub
parent b0936e4a44
commit a1ce02d740
6 changed files with 317 additions and 268 deletions


@@ -0,0 +1,19 @@
import random
import numpy as np
import torch
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
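The new `seed_all` helper is what the tests below rely on to make parameter initialization identical across ranks while keeping per-rank input data distinct. A minimal usage sketch, assuming a CUDA device and an (optionally) initialized process group; the helper function name and tensor shapes here are illustrative, not part of the diff:

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.testing.random import seed_all

def build_model_and_data():
    # same seed on every rank -> identical parameter initialization
    seed_all(2009)
    model = nn.Linear(128, 256).cuda()

    # rank-dependent seed -> each rank draws different input data
    rank = dist.get_rank() if dist.is_initialized() else 0
    seed_all(2021 + rank)
    data = torch.randn(32, 128).cuda()
    return model, data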


@@ -1,11 +1,13 @@
import math
import torch
import torch.distributed as dist
from torch._six import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import is_model_parallel_parameter
import torch.distributed as dist
def flatten(input_):
@@ -99,19 +101,24 @@ def split_half_float_double(tensor_list):
return buckets
def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
"""
Reduce the tensor in the data parallel process group
:param tensor: A tensor object to reduce/all-reduce
:param dtype: The data type used in communication
:param dst_rank: The destination rank for reduce. If dst_rank is None,
    all-reduce will be used instead of reduce. Default is None.
:param parallel_mode: Communication parallel mode
:type tensor: torch.Tensor
:type dtype: torch.dtype
:type dtype: torch.dtype, optional
:type dst_rank: int, optional
:type parallel_mode: ParallelMode, optional
"""
# use the original dtype
if dtype is None:
dtype = tensor.dtype
# cast the data to specified dtype for reduce/all-reduce
if tensor.dtype != dtype:
@@ -139,6 +146,7 @@ def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA)
local_rank = gpc.get_local_rank(parallel_mode)
if use_all_reduce or dst_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
@@ -238,7 +246,7 @@ def sync_param(flat_tensor, tensor_list):
Synchronize the flattened tensor and unflattened tensor list. When
a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
a new tensor is created. Thus, the flat tensor and the original tensor list do not
share the same memory space. This function will update the tensor list so that
they point to the same value.
:param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
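With `dtype=None` now the default, `reduce_tensor` communicates in the tensor's own dtype unless the caller asks for a cast. A hedged sketch of the two call patterns; it assumes `colossalai.launch` has already set up the data-parallel group, and the import path is an assumption since the diff does not show the file name:

import torch

# import path assumed; the file name is not shown in this diff
from colossalai.zero.sharded_optim._utils import reduce_tensor

grad = torch.randn(1024, dtype=torch.float16, device='cuda')

# dtype=None (the new default): reduce/all-reduce directly in fp16
reduce_tensor(grad)

# explicit dtype: cast to fp32 for the communication, copy the result back into `grad`
reduce_tensor(grad, dtype=torch.float32)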


@@ -44,12 +44,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
max_scale: int = 2**32,
# grad clipping
clip_grad_norm=2.0,
clip_grad_norm=0.0,
verbose=False,
# communication
reduce_bucket_size=50000000,
communication_dtype=torch.float16,
reduce_bucket_size=1024 * 1024,
communication_dtype=None,
overlap_communication=False,
# stage 2
@@ -58,7 +58,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
mp_parallel_mode=ParallelMode.MODEL,
# cpu offload
cpu_offload=False):
cpu_offload=False,
# forced dtype
forced_dtype=None):
# TODO: add support for
# 1. fp16 master weights
@@ -112,6 +115,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# gradient clipping
self._clip_grad_norm = clip_grad_norm
if forced_dtype:
for group in self._optimizer.param_groups:
group_params = group['params']
for param in group_params:
param.data = param.data.to(forced_dtype)
self._dtype = forced_dtype
# check argument conflict
self._sanity_checks()
@@ -225,17 +235,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
fp32_partition_param.grad = fp32_partition_grad
# we do not need log information for optimizer, so comment them
# update the parameter with zero gradients for initialization of optimizer states
self._optimizer.step()
# self._optimizer.step()
# remove the grad of the parameter to save memory
for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
fp32_flat_tensor.grad = None
# for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
# fp32_flat_tensor.grad = None
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
assert self._dtype == torch.float16, \
f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
for param_group in self._optimizer.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
###########################################################
# Backward Reduction Hook
@@ -389,6 +403,18 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
# finish gradient reduction
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -465,7 +491,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
# broadcast the updated model weights
@@ -524,22 +550,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
############################
def sync_grad(self):
if not self._partition_grads:
self._reduce_grad_stage1()
else:
# TODO: support async comm in reduce
self._reduce_grad_stage2()
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, state in reduction_states.items():
reduction_states[tensor] = False
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# accumulate gradient
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
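Moving the gradient reduction out of `sync_grad()` and into `backward()` is what enables gradient accumulation: every micro-batch's gradients are reduced as soon as its backward pass finishes, and `sync_grad()` only folds the result into the averaged-gradient store. A hedged sketch of the loop this supports; `model`, `optimizer` (a `LowLevelZeroOptimizer`), `dataloader`, and the step count are placeholders:

accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss = model(batch).sum()
    optimizer.backward(loss)          # scales the loss, runs autograd, reduces grads
    optimizer.sync_grad()             # accumulate this micro-batch's reduced grads
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()              # update from the accumulated gradients
        optimizer.zero_grad()         # clear grads before the next accumulation window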


@@ -0,0 +1,167 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
# create model
zero1_model = TestModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
clip_grad_norm=1.0)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data):
# zero-dp forward
zero1_output = zero1_model(cur_data)
zero2_output = zero2_model(cur_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.sum().float())
zero2_optimizer.backward(zero2_output.sum().float())
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
zero1_optimizer.sync_grad()
zero2_optimizer.sync_grad()
fwd_bwd_func(0, input_data1)
fwd_bwd_func(1, input_data2)
# step
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_grad_acc():
local_rank = torch.distributed.get_rank()
grad_scale = 32
seed_all(2008)
# create models
zero_model = TestModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
initial_scale=grad_scale,
reduce_bucket_size=262144,
clip_grad_norm=1.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
# create data
seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero_output = zero_model(cur_data)
# torch-ddp forward
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
# zero-dp backward
zero_optimizer.backward(zero_output.sum().float())
# torch-ddp backward
torch_output.sum().backward()
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
unscale_grad = z1p.grad / grad_scale
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, unscale_grad)
zero_optimizer.sync_grad()
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
zero_optimizer.step()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
torch_optimizer.step()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
assert_close(p.data, z1p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc()
# exam_zero_1_2_grad_acc()
@pytest.mark.dist
def test_grad_accumulation():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_grad_accumulation()


@@ -1,161 +0,0 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
def check_equal(a, b, rtol=1e-4, atol=1e-3):
"""
This function checks if two tensors are equal within tolerance
"""
assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}'
def check_completely_equal(a, b):
"""
This function checks if two tensors are completely equal
"""
assert torch.all(a == b), f'a = {a}, b = {b}'
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_1_2_grad_clip():
# create model
zero1_model = TestModel().cuda().half()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
clip_grad_norm=1.0)
# create
input_data = torch.rand(32, 128).cuda().half()
# forward
zero1_output = zero1_model(input_data)
zero2_output = zero2_model(input_data)
check_completely_equal(zero1_output, zero2_output)
# backward
zero1_optimizer.backward(zero1_output.mean().float())
zero2_optimizer.backward(zero2_output.mean().float())
# check grad
# as this param is small, the backward reduction
# will not be fired
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
check_completely_equal(z1p.grad, z2p.grad)
# step
zero1_optimizer.sync_grad()
zero2_optimizer.sync_grad()
# step
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
check_completely_equal(z1p.data, z2p.data)
def exam_zero_1_grad_clip():
# create models
zero_model = TestModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()
torch_model = DDP(torch_model.cuda())
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
clip_grad_norm=1.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
# create
input_data = torch.rand(32, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.half())
# torch-ddp forward
torch_output = torch_model(input_data)
check_equal(zero_output, torch_output)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
# torch-ddp backward
torch_output.mean().backward()
# check grad
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
check_equal(p.grad, z1p.grad)
# zero-dp step
zero_optimizer.sync_grad()
zero_optimizer.step()
# torch ddp step
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
torch_optimizer.step()
# check updated param
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
check_equal(p.data, z1p.data, atol=5e-4)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_2_grad_clip()
exam_zero_1_grad_clip()
@pytest.mark.dist
def test_grad_clip():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_grad_clip()


@@ -6,27 +6,41 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
def check_equal(a, b):
"""
This function checks if two tensors are equal within tolerance
"""
assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def check_completely_equal(a, b):
"""
This function checks if two tensors are completely equal
"""
assert torch.all(a == b), f'a = {a}, b = {b}'
def half_close(a, b, loose=False):
rtol = None
atol = None
if loose:
rtol = 5e-2
atol = 5e-4
a = a.detach().half()
b = b.detach().half()
assert_close(a, b, rtol=rtol, atol=atol)
def check_sharded_param_consistency():
def exam_zero_1_2():
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
@@ -37,67 +51,54 @@ def check_sharded_param_consistency():
pg: partition gradients and optimizer states
"""
# create layers
oss_linear1 = nn.Linear(128, 256)
oss_linear2 = nn.Linear(256, 512)
local_rank = torch.distributed.get_rank()
seed_all(2001)
# create model
oss_model = nn.Sequential(oss_linear1, oss_linear2)
pg_model = copy.deepcopy(oss_model)
oss_model = oss_model.cuda().half()
pg_model = pg_model.cuda().half()
zero1_model = TestModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
oss_optimizer = LowLevelZeroOptimizer(oss_optimizer,
overlap_communication=True,
initial_scale=1,
clip_grad_norm=0.0)
pg_optimizer = LowLevelZeroOptimizer(pg_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=1,
clip_grad_norm=0.0)
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128)
# create data
seed_all(2001 + local_rank)
input_data = torch.randn(32, 128).cuda()
# create
input_data = torch.rand(32, 128).cuda().half()
zero1_output = zero1_model(input_data)
zero2_output = zero2_model(input_data)
assert torch.equal(zero1_output, zero2_output)
# forward
oss_output = oss_model(input_data)
pg_output = pg_model(input_data)
check_completely_equal(oss_output, pg_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.mean().float())
zero2_optimizer.backward(zero2_output.mean().float())
# backward
oss_optimizer.backward(oss_output.mean().float())
pg_optimizer.backward(pg_output.mean().float())
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
# check grad
# as this param is small, the backward reduction
# will not be fired
oss_linear1_grad = oss_model[0].weight.grad
oss_linear2_grad = oss_model[1].weight.grad
pg_linear1_grad = pg_model[0].weight.grad
pg_linear2_grad = pg_model[1].weight.grad
check_completely_equal(oss_linear1_grad, pg_linear1_grad)
check_completely_equal(oss_linear2_grad, pg_linear2_grad)
zero1_optimizer.sync_grad()
zero2_optimizer.sync_grad()
# step
oss_optimizer.sync_grad()
pg_optimizer.sync_grad()
# step
oss_optimizer.step()
pg_optimizer.step()
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
check_completely_equal(oss_model[0].weight, pg_model[0].weight)
check_completely_equal(oss_model[1].weight, pg_model[1].weight)
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
def check_sharded_optim_against_torch_ddp():
def exam_zero_1_torch_ddp():
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
@@ -106,20 +107,22 @@ def check_sharded_optim_against_torch_ddp():
We feed these two sets of models with the same input and check if the
differences in model output and updated parameters are within tolerance.
"""
local_rank = torch.distributed.get_rank()
seed_all(1453)
# create layer
zero_linear1 = nn.Linear(128, 256)
zero_linear2 = nn.Linear(256, 512)
# create model
zero_model = nn.Sequential(zero_linear1, zero_linear2)
# create models
zero_model = TestModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()
torch_model = DDP(torch_model.cuda())
# torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
torch_model = torch_model.cuda()
# for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# half_close(p.data, z1p.data)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
@@ -127,10 +130,11 @@ def check_sharded_optim_against_torch_ddp():
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
clip_grad_norm=0.0)
reduce_bucket_size=262144)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
# create
input_data = torch.rand(32, 128).cuda()
@@ -139,7 +143,7 @@ def check_sharded_optim_against_torch_ddp():
# torch-ddp forward
torch_output = torch_model(input_data)
check_equal(zero_output, torch_output)
half_close(zero_output, torch_output, loose=True)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
@@ -148,12 +152,8 @@ def check_sharded_optim_against_torch_ddp():
torch_output.mean().backward()
# check grad
zero_linear1_grad = zero_model[0].weight.grad
zero_linear2_grad = zero_model[1].weight.grad
torch_linear1_grad = torch_model.module[0].weight.grad
torch_linear2_grad = torch_model.module[1].weight.grad
check_equal(zero_linear1_grad, torch_linear1_grad)
check_equal(zero_linear2_grad, torch_linear2_grad)
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
half_close(p.grad, z1p.grad, loose=True)
# zero-dp step
zero_optimizer.sync_grad()
@@ -163,23 +163,24 @@ def check_sharded_optim_against_torch_ddp():
torch_optimizer.step()
# check updated param
check_equal(zero_model[0].weight, torch_model.module[0].weight)
check_equal(zero_model[1].weight, torch_model.module[1].weight)
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, torch.max(torch.abs(p.data - z1p.data)))
half_close(p.data, z1p.data, loose=True)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_sharded_optim_against_torch_ddp()
check_sharded_param_consistency()
exam_zero_1_torch_ddp()
exam_zero_1_2()
@pytest.mark.dist
def test_sharded_optim():
def test_zero_1_2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_sharded_optim()
test_zero_1_2()