mirror of https://github.com/hpcaitech/ColossalAI
update buffer size calculation (#3871)
parent dbb9659099
commit 1ee247a51c

@@ -1,7 +1,7 @@
 from collections import defaultdict
 from copy import copy
 from functools import partial
-from typing import Any, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping
 
 import torch
 import torch.distributed as dist

@@ -18,9 +18,19 @@ from colossalai.elixir.tensor import OutplaceTensor
 from colossalai.utils.model.experimental import LazyTensor
 
 
-def is_leaf_module(m: nn.Module):
+def calc_module_buffer(m: nn.Module, fused_check_func: Callable) -> int:
     special_modules = [nn.MultiheadAttention]
-    return type(m) in special_modules
+    buffer_size = 0
+    if type(m) in special_modules:
+        for p in m.parameters():
+            if p.requires_grad:
+                buffer_size += p.numel()
+    else:
+        for p in m.parameters(recurse=False):
+            if p.requires_grad and not fused_check_func(p):
+                buffer_size += p.numel()
+
+    return buffer_size
 
 
 def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):

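Note: a minimal, self-contained sketch of the sizing rule introduced above, for illustration only. The `never_fused` predicate below is a hypothetical stand-in for `fetcher.is_in_fused`, which is not shown in this diff.

```python
from typing import Callable

import torch.nn as nn


def calc_module_buffer(m: nn.Module, fused_check_func: Callable) -> int:
    # mirrors the patched helper: special modules are sized over all of their
    # parameters (recursively), other modules only over their own direct
    # parameters that are not already handled as fused
    special_modules = [nn.MultiheadAttention]
    buffer_size = 0
    if type(m) in special_modules:
        for p in m.parameters():
            if p.requires_grad:
                buffer_size += p.numel()
    else:
        for p in m.parameters(recurse=False):
            if p.requires_grad and not fused_check_func(p):
                buffer_size += p.numel()
    return buffer_size


def never_fused(p: nn.Parameter) -> bool:
    # hypothetical stand-in for fetcher.is_in_fused: treat nothing as fused
    return False


mha = nn.MultiheadAttention(embed_dim=64, num_heads=4)
lin = nn.Linear(64, 64)

# The recursive walk matters for nn.MultiheadAttention: its out_proj weights
# live in a child Linear module and would be missed with recurse=False.
print(calc_module_buffer(mha, never_fused))    # every MHA parameter, incl. out_proj
print(calc_module_buffer(lin, never_fused))    # 64 * 64 + 64 = 4160
```
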
@@ -170,20 +180,18 @@ class ElixirModule(nn.Module):
     def __init_buffer_storage(self):
         buffer_size = 0
         for submodule in self.modules():
-            sum_param_size = 0
-            recurse_flag = is_leaf_module(submodule)
-            for param in submodule.parameters(recurse=recurse_flag):
-                if not param.requires_grad:
-                    continue
-                assert param.dtype == self.dtype
-                sum_param_size += param.numel()
-            buffer_size = max(buffer_size, sum_param_size)
+            sub_size = calc_module_buffer(submodule, self.fetcher.is_in_fused)
+            buffer_size = max(buffer_size, sub_size)
         self.buffer = BufferStore(buffer_size, self.dtype)
         print('module buffer', self.buffer)
 
     def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
         # create an empty tensor
-        fake_grad = self.buffer.empty_like(grad)
+        if param.numel() <= self.buffer.buffer_size:
+            fake_grad = self.buffer.empty_like(grad)
+        else:
+            fake_grad = torch.empty_like(grad)
+            fake_grad.storage().resize_(0)
 
         with torch._C.DisableTorchFunction():
             chunk = self.fetcher.get_one_chunk(param)

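Note: the fallback branch above relies on dropping a tensor's backing storage while keeping its metadata. A rough sketch of that pattern in plain PyTorch (sizes are illustrative):

```python
import torch

grad = torch.randn(1024)          # stand-in for a gradient larger than the buffer
fake_grad = torch.empty_like(grad)
fake_grad.storage().resize_(0)    # release the backing storage, keep shape/dtype

print(fake_grad.numel())             # 1024: the shape metadata is unchanged
print(fake_grad.storage().size())    # 0: no memory stays allocated for the data
```
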
@@ -70,6 +70,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if self.clipping_flag:
             assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
 
+        self.max_fake_param_size = 0
         self.__init__optimizer()
 
         # Grad scaler

@@ -90,6 +91,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
         if init_step:
             # allocate memory before training
             self.__zero_step()
+            torch.cuda.empty_cache()
 
         if self.clipping_flag:
             for param_chunk in self.param_chunk_set:

@@ -98,10 +100,15 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
     def __zero_step(self):
         torch.cuda.empty_cache()
 
-        cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
-        buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
-        for _, zero_buffer in buffer_dict.items():
-            zero_buffer.zeros()
+        compute_type = self.module.buffer.buffer_dtype
+        device_list = ['cpu', 'cuda']
+        buffer_dict = dict()
+
+        for device in device_list:
+            temp_buffer = BufferStore(self.max_fake_param_size, compute_type, device)
+            buffer_dict[device] = temp_buffer
+        for _, temp_buffer in buffer_dict.items():
+            temp_buffer.zeros()
 
         for group in self.param_groups:
             for fake_param in group['params']:

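Note: the scratch buffers zeroed above are now sized by `max_fake_param_size` (the largest single fake-parameter extent) instead of the module's full parameter buffer, which keeps the one-off CPU and CUDA allocations smaller. A rough sketch of the same pattern with plain tensors standing in for `BufferStore` (size and dtype are illustrative):

```python
import torch

max_fake_param_size = 4096    # largest single fake-parameter extent
compute_type = torch.float16

devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
scratch = {
    device: torch.empty(max_fake_param_size, dtype=compute_type, device=device)
    for device in devices
}
for buf in scratch.values():
    buf.zero_()    # touch and zero each scratch buffer once before training
```
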
@@ -263,6 +270,7 @@ class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
                 fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
                 self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
                 self.param_to_range[fake_param] = range_pair
+                self.max_fake_param_size = max(self.max_fake_param_size, range_pair[1] - range_pair[0])
 
                 fake_params_list.append(fake_param)

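Note: `max_fake_param_size` is a running maximum over the `(begin, end)` range pairs recorded while the fake parameters are created, so that `__zero_step` can size its scratch buffers to cover any single fake parameter. A tiny worked example with made-up range pairs:

```python
# illustrative range pairs: (begin, end) offsets of three fake parameters
range_pairs = [(0, 1024), (1024, 1536), (1536, 4096)]

max_fake_param_size = 0
for range_pair in range_pairs:
    max_fake_param_size = max(max_fake_param_size, range_pair[1] - range_pair[0])

print(max_fake_param_size)    # 2560, the widest single extent (1536..4096)
```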