[zero] test gradient accumulation (#1964)

* [zero] fix memory leak for zero2

* [zero] test gradient accumulation

* [zero] remove grad clip test
Branch: pull/2038/head
HELSON 2022-11-29 13:00:30 +08:00 committed by GitHub
parent b0936e4a44
commit a1ce02d740
6 changed files with 317 additions and 268 deletions


@@ -0,0 +1,19 @@
import random

import numpy as np
import torch


def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        if cuda_deterministic:    # slower, more reproducible
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
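For reference, a minimal usage sketch of the helper above. The import path is taken from the test added later in this commit (`from colossalai.testing.random import seed_all`); the `rank` variable here is purely illustrative.

from colossalai.testing.random import seed_all

# seed python, numpy, torch and CUDA RNGs identically on every rank
# so that model initialization is reproducible across processes
seed_all(2009)

# offset the seed by the process rank so each data-parallel rank
# draws different input batches (the pattern the new tests use)
rank = 0    # illustrative; the tests obtain this from torch.distributed.get_rank()
seed_all(2021 + rank)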


@@ -1,11 +1,13 @@
import math
import torch
+import torch.distributed as dist
from torch._six import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

-from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
from colossalai.utils import is_model_parallel_parameter
-import torch.distributed as dist


def flatten(input_):
@@ -99,19 +101,24 @@ def split_half_float_double(tensor_list):
    return buckets


-def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA):
+def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
    """
    Reduce the tensor in the data parallel process group

    :param tensor: A tensor object to reduce/all-reduce
    :param dtype: The data type used in communication
    :param dst_rank: The source rank for reduce. If dst_rank is None,
        all-reduce will be used instead of reduce. Default is None.
+    :param parallel_mode: Communication parallel mode

    :type tensor: torch.Tensor
-    :type dtype: torch.dtype
+    :type dtype: torch.dtype, optional
    :type dst_rank: int, optional
+    :type parallel_mode: ParallelMode, optional
    """
+    # use the original dtype
+    if dtype is None:
+        dtype = tensor.dtype

    # cast the data to specified dtype for reduce/all-reduce
    if tensor.dtype != dtype:
@@ -139,6 +146,7 @@ def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA)
        local_rank = gpc.get_local_rank(parallel_mode)
        if use_all_reduce or dst_rank == local_rank:
            tensor.copy_(tensor_to_reduce)
+
    return tensor
@@ -238,7 +246,7 @@ def sync_param(flat_tensor, tensor_list):
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created. Thus, the flat tensor and original tensor list do not
    share the same memory space. This function will update the tensor list so that
    they point to the same value.

    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
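To make the new defaults explicit: when `dtype` is left as `None`, `reduce_tensor` now communicates in the tensor's own dtype, and when `dst_rank` is `None` an all-reduce is used instead of a reduce. Below is a standalone sketch of that decision logic written against plain `torch.distributed` (it is not the library helper itself; `group` is a placeholder and a process group is assumed to be initialized).

import torch
import torch.distributed as dist

def reduce_like_helper(tensor: torch.Tensor, dtype=None, dst_rank=None, group=None):
    # default to the tensor's own dtype, as reduce_tensor now does
    if dtype is None:
        dtype = tensor.dtype

    # cast only if a different communication dtype was requested
    tensor_to_reduce = tensor if tensor.dtype == dtype else tensor.to(dtype)

    if dst_rank is None:
        # no destination rank given: every rank receives the summed result
        dist.all_reduce(tensor_to_reduce, group=group)
    else:
        # otherwise only dst_rank holds the reduced value
        dist.reduce(tensor_to_reduce, dst=dst_rank, group=group)

    # copy the result back if we communicated in a different dtype
    if tensor_to_reduce is not tensor and (dst_rank is None or dst_rank == dist.get_rank(group)):
        tensor.copy_(tensor_to_reduce)
    return tensor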


@@ -44,12 +44,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 max_scale: int = 2**32,

                 # grad clipping
-                 clip_grad_norm=2.0,
+                 clip_grad_norm=0.0,
                 verbose=False,

                 # communication
-                 reduce_bucket_size=50000000,
-                 communication_dtype=torch.float16,
+                 reduce_bucket_size=1024 * 1024,
+                 communication_dtype=None,
                 overlap_communication=False,

                 # stage 2
@@ -58,7 +58,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 mp_parallel_mode=ParallelMode.MODEL,

                 # cpu offload
-                 cpu_offload=False):
+                 cpu_offload=False,
+
+                 # forced dtype
+                 forced_dtype=None):

        # TODO: add support for
        # 1. fp16 master weights
@@ -112,6 +115,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

+        if forced_dtype:
+            for group in self._optimizer.param_groups:
+                group_params = group['params']
+                for param in group_params:
+                    param.data = param.data.to(forced_dtype)
+            self._dtype = forced_dtype
+
        # check argument conflict
        self._sanity_checks()
@@ -225,17 +235,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            fp32_partition_grad = torch.zeros_like(fp32_partition_param)
            fp32_partition_param.grad = fp32_partition_grad

+        # we do not need log information for the optimizer, so comment these lines out
        # update the parameter with zero gradients for initialization of optimizer states
-        self._optimizer.step()
+        # self._optimizer.step()

        # remove the grad of the parameter to save memory
-        for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
-            fp32_flat_tensor.grad = None
+        # for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
+        #     fp32_flat_tensor.grad = None

    def _sanity_checks(self):
        assert torch.cuda.is_available(), 'CUDA is required'
-        assert self._dtype == torch.float16, \
-            f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
+        for param_group in self._optimizer.param_groups:
+            group_params = param_group['params']
+            for param in group_params:
+                assert param.dtype == self._dtype, \
+                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

    ###########################################################
    # Backward Reduction Hook
@@ -389,6 +403,18 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            loss = self.loss_scale * loss
        loss.backward(retain_graph=retain_graph)

+        # finish gradient reduction
+        if not self._partition_grads:
+            self._reduce_grad_stage1()
+        else:
+            # TODO: support async comm in reduce
+            self._reduce_grad_stage2()
+
+        # clear reduced grads
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
+
    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient
@@ -465,7 +491,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        # update fp16 partition updated by the current rank
        for group_id in range(len(self._fp16_param_groups)):
            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
-            fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
+            fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
            fp16_param.data.copy_(fp32_param)

        # broadcast the updated model weights
@@ -524,22 +550,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    ############################

    def sync_grad(self):
-        if not self._partition_grads:
-            self._reduce_grad_stage1()
-        else:
-            # TODO: support async comm in reduce
-            self._reduce_grad_stage2()
-
        # update param already reduced flag
        reduction_states = self._param_store.get_param_reduction_states()
        for tensor, state in reduction_states.items():
            reduction_states[tensor] = False

-        # clear reduced grads
-        if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-
        # accumulate gradient
        avg_gradients = self._grad_store._averaged_gradients
        for group_id in range(self.num_param_groups):
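With the reduction moved to the end of `backward()`, each `backward()`/`sync_grad()` pair now reduces and accumulates the gradients of one micro-batch before the next forward pass; the gradient-accumulation test added below exercises exactly this pattern. The other new knob is `forced_dtype`, which is not covered by the tests, so here is a hedged usage sketch (the model is made up, and the wrapper is assumed to be constructed inside an initialized colossalai/distributed context, e.g. a process started via `colossalai.launch` as in the tests).

import torch
import torch.nn as nn
from colossalai.zero import LowLevelZeroOptimizer

# sketch only: requires an initialized colossalai/distributed context
model = nn.Linear(128, 256).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# forced_dtype casts every parameter handed to the wrapped optimizer to fp16,
# rather than manually casting parameters beforehand; _sanity_checks() now only
# requires that all parameters share this dtype
optimizer = LowLevelZeroOptimizer(optimizer,
                                  overlap_communication=True,
                                  forced_dtype=torch.float16)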


@@ -0,0 +1,167 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def exam_zero_1_2_grad_acc():
    local_rank = torch.distributed.get_rank()
    seed_all(2009)

    # create model
    zero1_model = TestModel().cuda()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0)

    # create data
    seed_all(2021 + local_rank)
    input_data1 = torch.randn(32, 128).cuda()
    input_data2 = torch.randn(32, 128).cuda()

    def fwd_bwd_func(number, cur_data):
        # zero-dp forward
        zero1_output = zero1_model(cur_data)
        zero2_output = zero2_model(cur_data)
        assert torch.equal(zero1_output, zero2_output)

        # zero-dp backward
        zero1_optimizer.backward(zero1_output.sum().float())
        zero2_optimizer.backward(zero2_output.sum().float())

        for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
            if z2p.grad is not None:
                # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                assert torch.equal(z1p.grad, z2p.grad)

        zero1_optimizer.sync_grad()
        zero2_optimizer.sync_grad()

    fwd_bwd_func(0, input_data1)
    fwd_bwd_func(1, input_data2)

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_grad_acc():
    local_rank = torch.distributed.get_rank()
    grad_scale = 32
    seed_all(2008)

    # create models
    zero_model = TestModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda()
    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=False,
                                           initial_scale=grad_scale,
                                           reduce_bucket_size=262144,
                                           clip_grad_norm=1.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

    # create data
    seed_all(2022 + local_rank)
    input_data1 = torch.randn(32, 128).cuda()
    input_data2 = torch.randn(32, 128).cuda()

    def fwd_bwd_func(number, cur_data, check_flag):
        # zero-dp forward
        zero_output = zero_model(cur_data)

        # torch-ddp forward
        torch_output = torch_model(cur_data)
        assert torch.equal(zero_output, torch_output)

        # zero-dp backward
        zero_optimizer.backward(zero_output.sum().float())

        # torch-ddp backward
        torch_output.sum().backward()

        if check_flag:
            # check grad
            for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
                unscale_grad = z1p.grad / grad_scale
                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                assert torch.equal(p.grad, unscale_grad)

        zero_optimizer.sync_grad()

    fwd_bwd_func(0, input_data1, True)
    fwd_bwd_func(1, input_data2, False)

    zero_optimizer.step()
    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
    torch_optimizer.step()

    # check updated param
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
        # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
        assert_close(p.data, z1p.data)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_grad_acc()
    # exam_zero_1_2_grad_acc()


@pytest.mark.dist
def test_grad_accumulation():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_grad_accumulation()


@@ -1,161 +0,0 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


def check_equal(a, b, rtol=1e-4, atol=1e-3):
    """
    This function checks if two tensors are equal within tolerance
    """
    assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}'


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def exam_zero_1_2_grad_clip():
    # create model
    zero1_model = TestModel().cuda().half()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001)
    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0)

    # create data
    input_data = torch.rand(32, 128).cuda().half()

    # forward
    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    check_completely_equal(zero1_output, zero2_output)

    # backward
    zero1_optimizer.backward(zero1_output.mean().float())
    zero2_optimizer.backward(zero2_output.mean().float())

    # check grad
    # as this param is small, the backward reduction
    # will not be fired
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        check_completely_equal(z1p.grad, z2p.grad)

    # step
    zero1_optimizer.sync_grad()
    zero2_optimizer.sync_grad()

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        check_completely_equal(z1p.data, z2p.data)


def exam_zero_1_grad_clip():
    # create models
    zero_model = TestModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda())

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           clip_grad_norm=1.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

    # create data
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    check_equal(zero_output, torch_output)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        check_equal(p.grad, z1p.grad)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch ddp step
    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
    torch_optimizer.step()

    # check updated param
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        check_equal(p.data, z1p.data, atol=5e-4)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_2_grad_clip()
    exam_zero_1_grad_clip()


@pytest.mark.dist
def test_grad_clip():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_grad_clip()


@@ -6,27 +6,41 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close

import colossalai
+from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


-def check_equal(a, b):
-    """
-    This function checks if two tensors are equal within tolerance
-    """
-    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+class TestModel(nn.Module):
+
+    def __init__(self):
+        super(TestModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.linear2 = nn.Linear(256, 512)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x


-def check_completely_equal(a, b):
-    """
-    This function checks if two tensors are completely equal
-    """
-    assert torch.all(a == b), f'a = {a}, b = {b}'
+def half_close(a, b, loose=False):
+    rtol = None
+    atol = None
+    if loose:
+        rtol = 5e-2
+        atol = 5e-4
+
+    a = a.detach().half()
+    b = b.detach().half()
+
+    assert_close(a, b, rtol=rtol, atol=atol)


-def check_sharded_param_consistency():
+def exam_zero_1_2():
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite different communication
@@ -37,67 +51,54 @@ def check_sharded_param_consistency():
    pg: partition gradients and optimizer states
    """
+    local_rank = torch.distributed.get_rank()
+    seed_all(2001)

-    # create layers
-    oss_linear1 = nn.Linear(128, 256)
-    oss_linear2 = nn.Linear(256, 512)
-
    # create model
-    oss_model = nn.Sequential(oss_linear1, oss_linear2)
-    pg_model = copy.deepcopy(oss_model)
-
-    oss_model = oss_model.cuda().half()
-    pg_model = pg_model.cuda().half()
+    zero1_model = TestModel().cuda()
+    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
-    oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
-    pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
-    oss_optimizer = LowLevelZeroOptimizer(oss_optimizer,
-                                          overlap_communication=True,
-                                          initial_scale=1,
-                                          clip_grad_norm=0.0)
-    pg_optimizer = LowLevelZeroOptimizer(pg_optimizer,
-                                         overlap_communication=True,
-                                         partition_grad=True,
-                                         initial_scale=1,
-                                         clip_grad_norm=0.0)
+    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
+    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
+    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
+                                            overlap_communication=True,
+                                            initial_scale=128,
+                                            verbose=True)
+    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
+                                            overlap_communication=True,
+                                            partition_grad=True,
+                                            initial_scale=128)
+    # create data
+    seed_all(2001 + local_rank)
+    input_data = torch.randn(32, 128).cuda()

-    # create
-    input_data = torch.rand(32, 128).cuda().half()
-
-    # forward
-    oss_output = oss_model(input_data)
-    pg_output = pg_model(input_data)
-    check_completely_equal(oss_output, pg_output)
+    zero1_output = zero1_model(input_data)
+    zero2_output = zero2_model(input_data)
+    assert torch.equal(zero1_output, zero2_output)

-    # backward
-    oss_optimizer.backward(oss_output.mean().float())
-    pg_optimizer.backward(pg_output.mean().float())
+    # zero-dp backward
+    zero1_optimizer.backward(zero1_output.mean().float())
+    zero2_optimizer.backward(zero2_output.mean().float())

-    # check grad
-    # as this param is small, the backward reduction
-    # will not be fired
-    oss_linear1_grad = oss_model[0].weight.grad
-    oss_linear2_grad = oss_model[1].weight.grad
-    pg_linear1_grad = pg_model[0].weight.grad
-    pg_linear2_grad = pg_model[1].weight.grad
-    check_completely_equal(oss_linear1_grad, pg_linear1_grad)
-    check_completely_equal(oss_linear2_grad, pg_linear2_grad)
+    for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
+        if z2p.grad is not None:
+            # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
+            assert torch.equal(z1p.grad, z2p.grad)

-    # step
-    oss_optimizer.sync_grad()
-    pg_optimizer.sync_grad()
+    zero1_optimizer.sync_grad()
+    zero2_optimizer.sync_grad()

    # step
-    oss_optimizer.step()
-    pg_optimizer.step()
+    zero1_optimizer.step()
+    zero2_optimizer.step()

    # check updated param
-    check_completely_equal(oss_model[0].weight, pg_model[0].weight)
-    check_completely_equal(oss_model[1].weight, pg_model[1].weight)
+    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
+        assert torch.equal(z1p.data, z2p.data)


-def check_sharded_optim_against_torch_ddp():
+def exam_zero_1_torch_ddp():
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters
@@ -106,20 +107,22 @@ def check_sharded_optim_against_torch_ddp():
    We feed these two sets of models with the same input and check if the
    differences in model output and updated parameters are within tolerance.
    """
+    local_rank = torch.distributed.get_rank()
+    seed_all(1453)

-    # create layer
-    zero_linear1 = nn.Linear(128, 256)
-    zero_linear2 = nn.Linear(256, 512)
-
-    # create model
-    zero_model = nn.Sequential(zero_linear1, zero_linear2)
+    # create models
+    zero_model = TestModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
-    torch_model = DDP(torch_model.cuda())
+    # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+    torch_model = torch_model.cuda()
+    # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+    #     half_close(p.data, z1p.data)

    # create optimizer
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
+    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
@@ -127,10 +130,11 @@ def check_sharded_optim_against_torch_ddp():
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
-                                           clip_grad_norm=0.0)
+                                           reduce_bucket_size=262144)

-    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+    seed_all(1453 + local_rank)
    # create
    input_data = torch.rand(32, 128).cuda()
@@ -139,7 +143,7 @@ def check_sharded_optim_against_torch_ddp():

    # torch-ddp forward
    torch_output = torch_model(input_data)
-    check_equal(zero_output, torch_output)
+    half_close(zero_output, torch_output, loose=True)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())
@@ -148,12 +152,8 @@ def check_sharded_optim_against_torch_ddp():
    torch_output.mean().backward()

    # check grad
-    zero_linear1_grad = zero_model[0].weight.grad
-    zero_linear2_grad = zero_model[1].weight.grad
-    torch_linear1_grad = torch_model.module[0].weight.grad
-    torch_linear2_grad = torch_model.module[1].weight.grad
-    check_equal(zero_linear1_grad, torch_linear1_grad)
-    check_equal(zero_linear2_grad, torch_linear2_grad)
+    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+        half_close(p.grad, z1p.grad, loose=True)

    # zero-dp step
    zero_optimizer.sync_grad()
@@ -163,23 +163,24 @@ def check_sharded_optim_against_torch_ddp():
    torch_optimizer.step()

    # check updated param
-    check_equal(zero_model[0].weight, torch_model.module[0].weight)
-    check_equal(zero_model[1].weight, torch_model.module[1].weight)
+    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+        # print(n, torch.max(torch.abs(p.data - z1p.data)))
+        half_close(p.data, z1p.data, loose=True)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

-    check_sharded_optim_against_torch_ddp()
-    check_sharded_param_consistency()
+    exam_zero_1_torch_ddp()
+    exam_zero_1_2()


@pytest.mark.dist
-def test_sharded_optim():
+def test_zero_1_2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
-    test_sharded_optim()
+    test_zero_1_2()