mirror of https://github.com/hpcaitech/ColossalAI
[zero] fix memory leak for zero2 (#1955)
parent
60abd86d6a
commit
7066dfbf82
|
@ -48,7 +48,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||||
verbose=False,
|
verbose=False,
|
||||||
|
|
||||||
# communication
|
# communication
|
||||||
reduce_bucket_size=500000000,
|
reduce_bucket_size=50000000,
|
||||||
communication_dtype=torch.float16,
|
communication_dtype=torch.float16,
|
||||||
overlap_communication=False,
|
overlap_communication=False,
|
||||||
|
|
||||||
|
@ -125,14 +125,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||||
# partition these param groups for data parallel training
|
# partition these param groups for data parallel training
|
||||||
# and add buffers to parameter store for future access
|
# and add buffers to parameter store for future access
|
||||||
for group_id, param_group in enumerate(self._optimizer.param_groups):
|
for group_id, param_group in enumerate(self._optimizer.param_groups):
|
||||||
params = param_group['params']
|
group_params = param_group['params']
|
||||||
|
|
||||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||||
self._fp16_param_groups[group_id] = params
|
self._fp16_param_groups[group_id] = group_params
|
||||||
|
|
||||||
# assign parameters to ranks
|
# assign parameters to ranks
|
||||||
# the params in the list are sorted
|
# the params in the list are sorted
|
||||||
params_per_rank = self._partition_param_list(params)
|
params_per_rank = self._partition_param_list(group_params)
|
||||||
|
|
||||||
# store the mapping between param to rank
|
# store the mapping between param to rank
|
||||||
# each param should belong to only one rank
|
# each param should belong to only one rank
|
||||||
|
@ -143,14 +143,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||||
|
|
||||||
# move to cpu to make room to create the flat tensor
|
# move to cpu to make room to create the flat tensor
|
||||||
# move_tensor(params, device='cpu')
|
# move_tensor(params, device='cpu')
|
||||||
for param in params:
|
for param in group_params:
|
||||||
param.data = param.data.cpu()
|
param.data = param.data.cpu()
|
||||||
|
|
||||||
# flatten the reordered tensors
|
# flatten the reordered tensors
|
||||||
for rank in range(self._world_size):
|
for rank in range(self._world_size):
|
||||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||||
flat_tensor = flatten(tensor_list)
|
with torch.no_grad():
|
||||||
flat_tensor = flat_tensor.cuda()
|
flat_tensor = flatten(tensor_list)
|
||||||
|
flat_tensor = flat_tensor.data.cuda()
|
||||||
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
|
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
|
||||||
|
|
||||||
# sync parameters
|
# sync parameters
|
||||||
|
@ -161,7 +162,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||||
|
|
||||||
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
||||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
|
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
|
||||||
fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
|
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
||||||
device = 'cpu' if self._cpu_offload else get_current_device()
|
device = 'cpu' if self._cpu_offload else get_current_device()
|
||||||
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
||||||
fp32_flat_current_rank.requires_grad = True
|
fp32_flat_current_rank.requires_grad = True
|
||||||
|
@ -384,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||||
# torch.optim.Optimizer methods
|
# torch.optim.Optimizer methods
|
||||||
################################
|
################################
|
||||||
|
|
||||||
def backward(self, loss, retain_graph=True):
|
def backward(self, loss, retain_graph=False):
|
||||||
loss = self.loss_scale * loss
|
loss = self.loss_scale * loss
|
||||||
loss.backward(retain_graph=retain_graph)
|
loss.backward(retain_graph=retain_graph)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
|
|
||||||
|
|
||||||
|
def check_equal(a, b, rtol=1e-4, atol=1e-3):
|
||||||
|
"""
|
||||||
|
This function checks if two tensors are equal within tolerance
|
||||||
|
"""
|
||||||
|
assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}'
|
||||||
|
|
||||||
|
|
||||||
|
def check_completely_equal(a, b):
|
||||||
|
"""
|
||||||
|
This function checks if two tensors are completely equal
|
||||||
|
"""
|
||||||
|
assert torch.all(a == b), f'a = {a}, b = {b}'
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TestModel, self).__init__()
|
||||||
|
self.linear1 = nn.Linear(128, 256)
|
||||||
|
self.linear2 = nn.Linear(256, 512)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear1(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def exam_zero_1_2_grad_clip():
|
||||||
|
# create model
|
||||||
|
zero1_model = TestModel().cuda().half()
|
||||||
|
zero2_model = copy.deepcopy(zero1_model)
|
||||||
|
|
||||||
|
# create optimizer
|
||||||
|
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001)
|
||||||
|
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001)
|
||||||
|
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
|
||||||
|
overlap_communication=True,
|
||||||
|
initial_scale=32,
|
||||||
|
clip_grad_norm=1.0,
|
||||||
|
verbose=True)
|
||||||
|
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
|
||||||
|
overlap_communication=True,
|
||||||
|
partition_grad=True,
|
||||||
|
initial_scale=32,
|
||||||
|
clip_grad_norm=1.0)
|
||||||
|
|
||||||
|
# create
|
||||||
|
input_data = torch.rand(32, 128).cuda().half()
|
||||||
|
|
||||||
|
# forward
|
||||||
|
zero1_output = zero1_model(input_data)
|
||||||
|
zero2_output = zero2_model(input_data)
|
||||||
|
check_completely_equal(zero1_output, zero2_output)
|
||||||
|
|
||||||
|
# backward
|
||||||
|
zero1_optimizer.backward(zero1_output.mean().float())
|
||||||
|
zero2_optimizer.backward(zero2_output.mean().float())
|
||||||
|
|
||||||
|
# check grad
|
||||||
|
# as this param is small, the backward reduction
|
||||||
|
# will not be fired
|
||||||
|
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||||
|
check_completely_equal(z1p.grad, z2p.grad)
|
||||||
|
|
||||||
|
# step
|
||||||
|
zero1_optimizer.sync_grad()
|
||||||
|
zero2_optimizer.sync_grad()
|
||||||
|
|
||||||
|
# step
|
||||||
|
zero1_optimizer.step()
|
||||||
|
zero2_optimizer.step()
|
||||||
|
|
||||||
|
# check updated param
|
||||||
|
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||||
|
check_completely_equal(z1p.data, z2p.data)
|
||||||
|
|
||||||
|
|
||||||
|
def exam_zero_1_grad_clip():
|
||||||
|
# create models
|
||||||
|
zero_model = TestModel()
|
||||||
|
torch_model = copy.deepcopy(zero_model)
|
||||||
|
|
||||||
|
zero_model = zero_model.cuda().half()
|
||||||
|
torch_model = DDP(torch_model.cuda())
|
||||||
|
|
||||||
|
# create optimizer
|
||||||
|
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# we only test stage 1 here
|
||||||
|
# in `check_sharded_param_consistency.py`, we will test whether
|
||||||
|
# level 1 and 2 will produce exactly the same results
|
||||||
|
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
|
||||||
|
overlap_communication=True,
|
||||||
|
initial_scale=1,
|
||||||
|
clip_grad_norm=1.0)
|
||||||
|
|
||||||
|
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# create
|
||||||
|
input_data = torch.rand(32, 128).cuda()
|
||||||
|
|
||||||
|
# zero-dp forward
|
||||||
|
zero_output = zero_model(input_data.half())
|
||||||
|
|
||||||
|
# torch-ddp forward
|
||||||
|
torch_output = torch_model(input_data)
|
||||||
|
check_equal(zero_output, torch_output)
|
||||||
|
|
||||||
|
# zero-dp backward
|
||||||
|
zero_optimizer.backward(zero_output.mean().float())
|
||||||
|
|
||||||
|
# torch-ddp backward
|
||||||
|
torch_output.mean().backward()
|
||||||
|
|
||||||
|
# check grad
|
||||||
|
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
|
||||||
|
check_equal(p.grad, z1p.grad)
|
||||||
|
|
||||||
|
# zero-dp step
|
||||||
|
zero_optimizer.sync_grad()
|
||||||
|
zero_optimizer.step()
|
||||||
|
|
||||||
|
# torch ddp step
|
||||||
|
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||||
|
torch_optimizer.step()
|
||||||
|
|
||||||
|
# check updated param
|
||||||
|
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
|
||||||
|
check_equal(p.data, z1p.data, atol=5e-4)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
|
||||||
|
exam_zero_1_2_grad_clip()
|
||||||
|
exam_zero_1_grad_clip()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
def test_grad_clip():
|
||||||
|
world_size = 2
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_grad_clip()
|
Loading…
Reference in New Issue