update buffer size calculation (#3871)

pull/3873/head
Haichen Huang 2023-05-31 18:23:58 +08:00 committed by GitHub
parent dbb9659099
commit 1ee247a51c
2 changed files with 32 additions and 16 deletions


@@ -1,7 +1,7 @@
 from collections import defaultdict
 from copy import copy
 from functools import partial
-from typing import Any, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping
 
 import torch
 import torch.distributed as dist
@@ -18,9 +18,19 @@ from colossalai.elixir.tensor import OutplaceTensor
 from colossalai.utils.model.experimental import LazyTensor
 
 
-def is_leaf_module(m: nn.Module):
+def calc_module_buffer(m: nn.Module, fused_check_func: Callable) -> int:
     special_modules = [nn.MultiheadAttention]
-    return type(m) in special_modules
+    buffer_size = 0
+    if type(m) in special_modules:
+        for p in m.parameters():
+            if p.requires_grad:
+                buffer_size += p.numel()
+    else:
+        for p in m.parameters(recurse=False):
+            if p.requires_grad and not fused_check_func(p):
+                buffer_size += p.numel()
+    return buffer_size
 
 
 def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
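For readers skimming the diff: a self-contained sketch of what the new helper computes. The `never_fused` stub below stands in for the real fused-parameter check (passed in later as `self.fetcher.is_in_fused`), and the example modules are only illustrative.

import torch.nn as nn

def calc_module_buffer(m, fused_check_func):
    # same logic as the helper added above
    special_modules = [nn.MultiheadAttention]
    buffer_size = 0
    if type(m) in special_modules:
        # special-cased modules are counted with all nested parameters
        for p in m.parameters():
            if p.requires_grad:
                buffer_size += p.numel()
    else:
        # otherwise count only the module's own, not-yet-fused parameters
        for p in m.parameters(recurse=False):
            if p.requires_grad and not fused_check_func(p):
                buffer_size += p.numel()
    return buffer_size

never_fused = lambda p: False    # stub: treat no parameter as fused
print(calc_module_buffer(nn.Linear(1024, 1024), never_fused))          # 1049600 = 1024*1024 + 1024
print(calc_module_buffer(nn.MultiheadAttention(512, 8), never_fused))  # 1050624 = 3*512*512 + 3*512 + 512*512 + 512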
@@ -170,20 +180,18 @@ class ElixirModule(nn.Module):
     def __init_buffer_storage(self):
         buffer_size = 0
         for submodule in self.modules():
-            sum_param_size = 0
-            recurse_flag = is_leaf_module(submodule)
-            for param in submodule.parameters(recurse=recurse_flag):
-                if not param.requires_grad:
-                    continue
-                assert param.dtype == self.dtype
-                sum_param_size += param.numel()
-            buffer_size = max(buffer_size, sum_param_size)
+            sub_size = calc_module_buffer(submodule, self.fetcher.is_in_fused)
+            buffer_size = max(buffer_size, sub_size)
         self.buffer = BufferStore(buffer_size, self.dtype)
         print('module buffer', self.buffer)
 
     def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
         # create an empty tensor
-        fake_grad = self.buffer.empty_like(grad)
+        if param.numel() <= self.buffer.buffer_size:
+            fake_grad = self.buffer.empty_like(grad)
+        else:
+            fake_grad = torch.empty_like(grad)
+            fake_grad.storage().resize_(0)
         with torch._C.DisableTorchFunction():
             chunk = self.fetcher.get_one_chunk(param)
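The `_gradient_handler` change pairs with the new sizing: a parameter larger than the shared buffer now gets a placeholder gradient whose storage is released immediately. A minimal sketch of that placeholder trick in plain PyTorch (newer PyTorch versions may warn that `Tensor.storage()` is deprecated in favour of `untyped_storage()`):

import torch

grad = torch.randn(4, 4)
fake_grad = torch.empty_like(grad)
fake_grad.storage().resize_(0)       # drop the backing memory, keep shape/dtype metadata
print(fake_grad.shape)               # torch.Size([4, 4])
print(fake_grad.storage().size())    # 0 elements actually allocated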


@@ -70,6 +70,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if self.clipping_flag:
             assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
 
+        self.max_fake_param_size = 0
         self.__init__optimizer()
 
         # Grad scaler
@@ -90,6 +91,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if init_step:
             # allocate memory before training
             self.__zero_step()
+            torch.cuda.empty_cache()
 
         if self.clipping_flag:
             for param_chunk in self.param_chunk_set:
@@ -98,10 +100,15 @@
 
     def __zero_step(self):
         torch.cuda.empty_cache()
-        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
-        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
-        for _, zero_buffer in buffer_dict.items():
-            zero_buffer.zeros()
+        compute_type = self.module.buffer.buffer_dtype
+        device_list = ['cpu', 'cuda']
+        buffer_dict = dict()
+        for device in device_list:
+            temp_buffer = BufferStore(self.max_fake_param_size, compute_type, device)
+            buffer_dict[device] = temp_buffer
+        for _, temp_buffer in buffer_dict.items():
+            temp_buffer.zeros()
 
         for group in self.param_groups:
             for fake_param in group['params']:
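The rewritten `__zero_step` no longer mirrors the whole module buffer on the CPU; it allocates one zero-filled scratch buffer per device, sized by `self.max_fake_param_size` (tracked in the hunk below). A rough equivalent in plain PyTorch, with hypothetical sizes:

import torch

max_fake_param_size = 4096        # hypothetical: longest fake-parameter range
compute_type = torch.float16      # hypothetical compute dtype

devices = ['cpu']
if torch.cuda.is_available():
    devices.append('cuda')

# one small zeroed scratch buffer per device instead of a full-size CPU copy
scratch = {dev: torch.zeros(max_fake_param_size, dtype=compute_type, device=dev)
           for dev in devices}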
@@ -263,6 +270,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
                 fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                 self.param_to_range[fake_param] = range_pair
+                self.max_fake_param_size = max(self.max_fake_param_size, range_pair[1] - range_pair[0])
                 fake_params_list.append(fake_param)
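A purely illustrative view of the bookkeeping added in this last hunk: each fake parameter covers a half-open range of its optimizer chunk, and the optimizer keeps the longest such range so the `__zero_step` scratch buffers are large enough for any of them (the ranges below are made up).

range_pairs = [(0, 1024), (1024, 4096), (0, 512)]    # hypothetical (begin, end) pairs

max_fake_param_size = 0
for begin, end in range_pairs:
    max_fake_param_size = max(max_fake_param_size, end - begin)

print(max_fake_param_size)    # 3072 -> element count the scratch buffers must hold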