[hotfix] fix lightning error (#2529)

pull/2534/head
HELSON 2023-01-31 10:40:39 +08:00 committed by GitHub
parent b55deb0662
commit a4ed9125ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 102 deletions

View File

@ -5,6 +5,7 @@ from typing import Dict, Iterable, List, Optional, Set
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
@ -218,11 +219,15 @@ class ZeroDDP(ColoDDP):
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = []
self.fp32_params: List[ColoTensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
cpu_offload = self.gemini_manager.policy_name != 'cuda'
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
@ -234,50 +239,17 @@ class ZeroDDP(ColoDDP):
for p in module.parameters():
param_order.append(p)
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self._cast_buffers()
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger()
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
def _post_forward(self):
"""This function is only triggered for inference.
@ -318,10 +290,23 @@ class ZeroDDP(ColoDDP):
continue
p.grad = None
def _pre_bacward(self):
# set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced
for param in self.param2name:
if not is_ddp_ignored(param):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.")
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
@ -329,6 +314,7 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
self._pre_bacward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
@ -343,7 +329,9 @@ class ZeroDDP(ColoDDP):
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
@ -367,30 +355,7 @@ class ZeroDDP(ColoDDP):
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
"""
Args:
strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
Returns:
dict:
a dictionary containing a whole state of the module
Example:
>>> module.state_dict().keys()
['bias', 'weight']
"""
if strict:
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
@ -461,19 +426,24 @@ class ZeroDDP(ColoDDP):
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
ddp_param_list = []
for name, param in self.named_parameters():
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
ddp_param_list.append((name, param))
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
destination[prefix + name] = record_parameter
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
@ -605,17 +575,15 @@ class ZeroDDP(ColoDDP):
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
ddp_param_list = []
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
else:
ddp_param_list.append((name, param))
fp32_to_name = dict()
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
@ -662,6 +630,60 @@ class ZeroDDP(ColoDDP):
if input_name not in local_state:
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()

View File

@ -49,6 +49,10 @@ class GeminiDDP(ZeroDDP):
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_mb is None:
search_range_mb = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,

View File

@ -80,13 +80,11 @@ def get_static_torch_model(zero_ddp_model,
from colossalai.nn.parallel import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0:
# record the mapping relationship between colo parameters and torch parameters
colo_to_torch = dict()
for (name, colo_module), (_, torch_module) in \
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
# clean the parameter list of the new torch module
@ -94,17 +92,10 @@ def get_static_torch_model(zero_ddp_model,
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
# get the full name of the parameter
full_param_name = name + ('.' if name else '') + sufix_param_name
if full_param_name not in state_dict:
# this means the parameter is shared by multiple modules
# we should use colo_to_torch to get the torch parameter created before
assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
torch_param = colo_to_torch[param]
else:
# we meet the parameter the first time, just use the state dict to get the data
assert full_param_name in state_dict, \
f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
state_param = state_dict[full_param_name]
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
colo_to_torch[param] = torch_param
setattr(torch_module, sufix_param_name, torch_param)
dist.barrier()