[hotfix] fix lightning error (#2529)

HELSON 2023-01-31 10:40:39 +08:00 committed by GitHub
parent b55deb0662
commit a4ed9125ac
3 changed files with 119 additions and 102 deletions

@@ -5,6 +5,7 @@ from typing import Dict, Iterable, List, Optional, Set
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
@@ -218,11 +219,15 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(gemini_manager)
-        self.fp32_params: List[ColoTensor] = []
+        self.fp32_params: List[ColoTensor] = list()
+        self.fp16_params: List[ColoParameter] = list()
         self.overflow_counter = 0
-        self.grads_device: Dict[torch.Tensor, torch.device] = {}
+        self.grads_device: Dict[torch.Tensor, torch.device] = dict()
+        self.param2name: Dict[nn.Parameter, str] = dict()
+        self.name2param: Dict[str, nn.Parameter] = dict()

-        cpu_offload = self.gemini_manager.policy_name != 'cuda'
+        self._cast_buffers()
+        self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
             # build chunk in param runtime visited order.
@@ -234,50 +239,17 @@ class ZeroDDP(ColoDDP):
             for p in module.parameters():
                 param_order.append(p)

-        ddp_pg = ColoProcessGroup()
-        for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
-
-            if strict_ddp_mode:
-                if not p.is_replicate():
-                    p.set_dist_spec(ReplicaSpec())
-                p.set_process_group(pg=ddp_pg)
-
-            if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
-                continue
-
-            fp32_data = p.data.float()
-            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
-            p.data = p.data.half()
-            dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.register_tensor(tensor=p,
-                                               group_type='fp16_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.chunk_manager.register_tensor(tensor=fp32_p,
-                                               group_type='fp32_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.gemini_manager.default_device
-        self.chunk_manager.close_all_groups()
-        self._cast_buffers()
-
-        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
-        for p, fp32_p in zip(params_list, self.fp32_params):
-            chunk_16 = self.chunk_manager.get_chunk(p)
-            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
-            chunk_32.init_pair(chunk_16)
-
-            # keep gathered chunks are in CUDA
-            if chunk_16.keep_gathered:
-                self.grads_device[p] = get_current_device()
-
-        self._logger = get_dist_logger()
+        self._init_chunks(param_order=param_order,
+                          strict_ddp_mode=strict_ddp_mode,
+                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
+                          pin_memory=pin_memory)
+
+        for name, param in module.named_parameters():
+            self.param2name[param] = name
+        for m_name, m_var in module.named_modules():
+            for p_name, p_var in m_var.named_parameters(recurse=False):
+                param_name = m_name + '.' + p_name if m_name else p_name
+                self.name2param[param_name] = p_var

     def _post_forward(self):
         """This function is only triggered for inference.
@@ -318,10 +290,23 @@ class ZeroDDP(ColoDDP):
                 continue
             p.grad = None

+    def _pre_bacward(self):
+        # set a visit label for all parameters
+        # the label is used to check whether the parameter is correctly reduced
+        for param in self.param2name:
+            if not is_ddp_ignored(param):
+                setattr(param, "_gemini_reduced", False)
+
     def _post_backward(self):
         if self.chunk_manager.accessed_mem != 0:
+            error_params = ["Reduction failed at followed parameters:"]
+            for param in self.param2name:
+                if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
+                    error_params.append(self.param2name[param])
+            error_str = "\n\t".join(error_params)
             raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
-                               "The most possible reason is that the model is not compatible with ZeroDDP.")
+                               "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+                               f"{error_str}")
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
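
The two additions above implement a simple visit-label check: `_pre_bacward` clears a `_gemini_reduced` flag on every tracked parameter before the backward pass, the gradient-reduction path sets it, and `_post_backward` lists every parameter whose flag is still unset when chunk memory was not fully released. A minimal standalone sketch of the same pattern, with illustrative names (`ReductionTracker`, `mark_reduced`) that are not part of ColossalAI:

# Minimal sketch of the visit-label check; ReductionTracker and mark_reduced
# are illustrative names, not ColossalAI API.
import torch.nn as nn

class ReductionTracker:
    def __init__(self, module: nn.Module):
        self.param2name = {p: n for n, p in module.named_parameters()}

    def pre_backward(self):
        for p in self.param2name:
            p._reduced = False                       # clear the visit label

    def mark_reduced(self, p: nn.Parameter):
        p._reduced = True                            # done by the reduction hook in ZeroDDP

    def post_backward(self):
        missed = [n for p, n in self.param2name.items() if not p._reduced]
        if missed:
            raise RuntimeError("unreduced parameters:\n\t" + "\n\t".join(missed))

model = nn.Linear(2, 2)
tracker = ReductionTracker(model)
tracker.pre_backward()
tracker.mark_reduced(model.weight)                   # the bias is never marked ...
try:
    tracker.post_backward()                          # ... so it is reported here
except RuntimeError as err:
    print(err)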
@@ -329,6 +314,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
+        self._pre_bacward()
         with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
@@ -343,7 +329,9 @@ class ZeroDDP(ColoDDP):
         free_storage(empty_grad)
         with torch._C.DisableTorchFunction():
             chunk = self.chunk_manager.get_chunk(p)
-            assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+            if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
+                raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                                   "Some unsupported torch function is operated upon this parameter.")
             self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
             chunk.copy_tensor_to_chunk_slice(p, grad)
             reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -367,30 +355,7 @@ class ZeroDDP(ColoDDP):
             for tensor in chunk.get_tensors():
                 self.grads_device[tensor] = device

-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        """
-        Args:
-            strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
-
-        Returns:
-            dict:
-                a dictionary containing a whole state of the module
-
-        Example:
-            >>> module.state_dict().keys()
-            ['bias', 'weight']
-        """
-        if strict:
-            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
-            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
-            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
-        return self._non_strict_state_dict(destination=destination,
-                                            prefix=prefix,
-                                            keep_vars=keep_vars,
-                                            only_rank_0=only_rank_0)
-
-    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
         """Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -461,19 +426,24 @@ class ZeroDDP(ColoDDP):
         """
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

+        # get copies of fp32 parameters in CPU
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
-        ddp_param_list = []
-        for name, param in self.named_parameters():
-            if is_ddp_ignored(param):
-                # deal with ddp ignored parameters
-                destination[prefix + name] = param if keep_vars else param.detach()
-            else:
-                ddp_param_list.append((name, param))
-
-        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
-            if p is not None:
-                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-                record_parameter = param_to_save_data[fp32_p]
-                destination[prefix + name] = record_parameter
+        # get the mapping between copies and fp16 parameters
+        p_mapping = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            name = self.param2name[p]
+            assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+            record_parameter = param_to_save_data[fp32_p]
+            p_mapping[p] = record_parameter
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    destination[prefix + name] = param if keep_vars else param.detach()
+                else:
+                    destination[prefix + name] = p_mapping[param]
+        del p_mapping
+        del param_to_save_data

         # save all buffers
         for name, buf in self.named_buffers():
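
For reference, the rewritten body above resolves each saved tensor in two small lookups: fp16 parameter -> fp32 copy (taken from the chunk data), then parameter name -> tensor written into `destination`. A toy version of that two-step mapping, with a plain tensor standing in for the chunk-managed fp32 copy returned by `_get_param_to_save_data`:

# Toy version of the two-step lookup; a plain tensor stands in for the
# chunk-managed fp32 copy, and the dicts mirror param2name / p_mapping.
import torch
import torch.nn as nn

fp16_p = nn.Parameter(torch.ones(2, dtype=torch.float16))
fp32_copy = fp16_p.detach().float()                  # master copy kept for saving
param2name = {fp16_p: "linear.weight"}

p_mapping = {fp16_p: fp32_copy}                      # fp16 param -> saved tensor
destination = {param2name[p]: t for p, t in p_mapping.items()}

assert destination["linear.weight"].dtype == torch.float32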
@@ -605,17 +575,15 @@ class ZeroDDP(ColoDDP):
         def load_fp32_parameter(chunk_slice, data):
             chunk_slice.copy_(data.flatten())

-        ddp_param_list = []
         for name, param in self.named_parameters():
             if is_ddp_ignored(param):
                 # deal with ddp ignored parameters
                 load(name, param, param.copy_)
-            else:
-                ddp_param_list.append((name, param))

         fp32_to_name = dict()
-        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             if p is not None:
+                name = self.param2name[p]
                 fp32_to_name[fp32_p] = name

         chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
@@ -662,6 +630,60 @@ class ZeroDDP(ColoDDP):
             if input_name not in local_state:
                 unexpected_keys.append(key)

+    def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
+        ddp_pg = ColoProcessGroup()
+        for p in param_order.generate():
+            assert isinstance(p, ColoParameter)
+
+            # gather sharded parameters in the strict ddp mode
+            if strict_ddp_mode:
+                if not p.is_replicate():
+                    p.set_dist_spec(ReplicaSpec())
+                p.set_process_group(pg=ddp_pg)
+
+            # ignore the parameters with no gradient
+            if not p.requires_grad:
+                self.set_params_to_ignore([p])
+
+            # move ignored parameters to CUDA
+            if is_ddp_ignored(p):
+                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                continue
+
+            # create a fp32 parameter
+            fp32_data = p.data.float()
+            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+            # create a fp16 parameter
+            p.data = p.data.half()
+
+            # register the fp16 parameter and fp32 parameter in the chunk manager
+            dp_world_size = p.process_group.dp_world_size()
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+
+            self.fp16_params.append(p)
+            self.fp32_params.append(fp32_p)
+            self.grads_device[p] = self.gemini_manager.default_device
+
+        self.chunk_manager.close_all_groups()
+
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            chunk_16 = self.chunk_manager.get_chunk(p)
+            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
+            chunk_32.init_pair(chunk_16)
+
+            # keep gathered chunks are in CUDA
+            if chunk_16.keep_gathered:
+                self.grads_device[p] = get_current_device()
+
     def _cast_buffers(self):
         for buffer in self.module.buffers():
             buffer.data = buffer.cuda()
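
The new `_init_chunks` keeps `self.fp16_params` and `self.fp32_params` index-aligned, which is what lets `state_dict`, `load_state_dict` and the chunk pairing above use a plain `zip`. A minimal sketch of that fp16/fp32 master-weight pairing, using ordinary lists instead of chunk-managed tensors:

# Minimal sketch of the fp16/fp32 pairing set up by _init_chunks; ordinary
# lists replace the ChunkManager, and registration/offload details are omitted.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
fp16_params, fp32_params = [], []
for p in model.parameters():
    fp32_master = p.data.float().clone()   # fp32 copy kept for optimizer updates
    p.data = p.data.half()                 # the module itself now holds fp16 weights
    fp16_params.append(p)
    fp32_params.append(fp32_master)

# the two lists stay index-aligned, so later code can simply zip them
for p, master in zip(fp16_params, fp32_params):
    assert p.dtype == torch.float16 and master.dtype == torch.float32
    assert p.shape == master.shape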

@@ -49,6 +49,10 @@ class GeminiDDP(ZeroDDP):
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
+        # some ugly hotfix for the compatibility with Lightning
+        if search_range_mb is None:
+            search_range_mb = 32
+
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
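
The guard above is the actual Lightning fix: an integration layer that forwards optional keyword arguments can hand `search_range_mb=None` straight to `GeminiDDP`, and the chunk-size search needs a concrete value, so `None` now falls back to 32 MB before `init_chunk_manager` is called. A small sketch of the failure mode and the guard, with hypothetical wrapper names:

# Small sketch of the default-argument guard; init_chunk_search and
# lightning_style_wrapper are hypothetical stand-ins, only the 32 MB
# fallback mirrors the hotfix above.
from typing import Optional

def init_chunk_search(search_range_mb: Optional[int] = None) -> int:
    # mirror the hotfix: never let None reach the chunk-size search
    return 32 if search_range_mb is None else search_range_mb

def lightning_style_wrapper(**gemini_kwargs):
    # a thin wrapper that forwards its own default of None downstream
    return init_chunk_search(search_range_mb=gemini_kwargs.get("search_range_mb"))

assert lightning_style_wrapper() == 32                    # missing kwarg arrives as None
assert lightning_style_wrapper(search_range_mb=64) == 64  # explicit values pass through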

@@ -80,13 +80,11 @@ def get_static_torch_model(zero_ddp_model,
     from colossalai.nn.parallel import ZeroDDP
     assert isinstance(zero_ddp_model, ZeroDDP)

-    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
     colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)

     if not only_rank_0 or dist.get_rank() == 0:
-        # record the mapping relationship between colo parameters and torch parameters
-        colo_to_torch = dict()
         for (name, colo_module), (_, torch_module) in \
                 zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
             # clean the parameter list of the new torch module
@@ -94,17 +92,10 @@
             for sufix_param_name, param in colo_module.named_parameters(recurse=False):
                 # get the full name of the parameter
                 full_param_name = name + ('.' if name else '') + sufix_param_name
-                if full_param_name not in state_dict:
-                    # this means the parameter is shared by multiple modules
-                    # we should use colo_to_torch to get the torch parameter created before
-                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
-                    torch_param = colo_to_torch[param]
-                else:
-                    # we meet the parameter the first time, just use the state dict to get the data
-                    state_param = state_dict[full_param_name]
-                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
-                    colo_to_torch[param] = torch_param
+                assert full_param_name in state_dict, \
+                    f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
+                state_param = state_dict[full_param_name]
+                torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
+
                 setattr(torch_module, sufix_param_name, torch_param)

     dist.barrier()
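
With the `strict` branch gone, `get_static_torch_model` relies on the single full `state_dict` and simply asserts that every dotted parameter name is present before re-attaching it as a plain `torch.nn.Parameter`. A generic sketch of that name-based rebuild on a toy module (the real helper additionally shallow-copies the module tree and honours the `device`/`dtype` arguments):

# Generic sketch of the name-based parameter rebuild; toy module and state
# dict, the shallow-copy machinery of the real helper is omitted.
import torch
import torch.nn as nn

state_dict = {"0.weight": torch.zeros(2, 2), "0.bias": torch.zeros(2)}
torch_model = nn.Sequential(nn.Linear(2, 2))

for m_name, module in torch_model.named_modules():
    for p_name, _ in list(module.named_parameters(recurse=False)):
        full_name = f"{m_name}.{p_name}" if m_name else p_name
        assert full_name in state_dict, f"Can not find parameter `{full_name}`"
        setattr(module, p_name, nn.Parameter(state_dict[full_name].float()))

assert torch.all(torch_model[0].weight == 0)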