mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix lightning error (#2529)
parent
b55deb0662
commit
a4ed9125ac
|
@ -5,6 +5,7 @@ from typing import Dict, Iterable, List, Optional, Set
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
@ -218,11 +219,15 @@ class ZeroDDP(ColoDDP):
|
|||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(gemini_manager)
|
||||
self.fp32_params: List[ColoTensor] = []
|
||||
self.fp32_params: List[ColoTensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
self.param2name: Dict[nn.Parameter, str] = dict()
|
||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
|
||||
cpu_offload = self.gemini_manager.policy_name != 'cuda'
|
||||
self._cast_buffers()
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
# build chunk in param runtime visited order.
|
||||
|
@ -234,50 +239,17 @@ class ZeroDDP(ColoDDP):
|
|||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
ddp_pg = ColoProcessGroup()
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
p.data = p.data.half()
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
self._init_chunks(param_order=param_order,
|
||||
strict_ddp_mode=strict_ddp_mode,
|
||||
cpu_offload=self.gemini_manager.policy_name != 'cuda',
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
self._cast_buffers()
|
||||
|
||||
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
|
||||
for p, fp32_p in zip(params_list, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
for p_name, p_var in m_var.named_parameters(recurse=False):
|
||||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
|
@ -318,10 +290,23 @@ class ZeroDDP(ColoDDP):
|
|||
continue
|
||||
p.grad = None
|
||||
|
||||
def _pre_bacward(self):
|
||||
# set a visit label for all parameters
|
||||
# the label is used to check whether the parameter is correctly reduced
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param):
|
||||
setattr(param, "_gemini_reduced", False)
|
||||
|
||||
def _post_backward(self):
|
||||
if self.chunk_manager.accessed_mem != 0:
|
||||
error_params = ["Reduction failed at followed parameters:"]
|
||||
for param in self.param2name:
|
||||
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
|
||||
error_params.append(self.param2name[param])
|
||||
error_str = "\n\t".join(error_params)
|
||||
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.")
|
||||
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
|
||||
f"{error_str}")
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
|
@ -329,6 +314,7 @@ class ZeroDDP(ColoDDP):
|
|||
self.gemini_manager.post_iter()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self._pre_bacward()
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
@ -343,7 +329,9 @@ class ZeroDDP(ColoDDP):
|
|||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
|
||||
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
|
||||
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter.")
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
|
@ -367,30 +355,7 @@ class ZeroDDP(ColoDDP):
|
|||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
|
||||
"""
|
||||
Args:
|
||||
strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
|
||||
Example:
|
||||
|
||||
>>> module.state_dict().keys()
|
||||
['bias', 'weight']
|
||||
"""
|
||||
if strict:
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
|
||||
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
return self._non_strict_state_dict(destination=destination,
|
||||
prefix=prefix,
|
||||
keep_vars=keep_vars,
|
||||
only_rank_0=only_rank_0)
|
||||
|
||||
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
|
@ -461,19 +426,24 @@ class ZeroDDP(ColoDDP):
|
|||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# get copies of fp32 parameters in CPU
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
# get the mapping between copies and fp16 parameters
|
||||
p_mapping = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
for name, param in self.name2param.items():
|
||||
if param is not None:
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = record_parameter
|
||||
destination[prefix + name] = p_mapping[param]
|
||||
del p_mapping
|
||||
del param_to_save_data
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
|
@ -605,17 +575,15 @@ class ZeroDDP(ColoDDP):
|
|||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
ddp_param_list = []
|
||||
for name, param in self.named_parameters():
|
||||
if is_ddp_ignored(param):
|
||||
# deal with ddp ignored parameters
|
||||
load(name, param, param.copy_)
|
||||
else:
|
||||
ddp_param_list.append((name, param))
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
if p is not None:
|
||||
name = self.param2name[p]
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
|
@ -662,6 +630,60 @@ class ZeroDDP(ColoDDP):
|
|||
if input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
|
||||
ddp_pg = ColoProcessGroup()
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
# gather sharded parameters in the strict ddp mode
|
||||
if strict_ddp_mode:
|
||||
if not p.is_replicate():
|
||||
p.set_dist_spec(ReplicaSpec())
|
||||
p.set_process_group(pg=ddp_pg)
|
||||
|
||||
# ignore the parameters with no gradient
|
||||
if not p.requires_grad:
|
||||
self.set_params_to_ignore([p])
|
||||
|
||||
# move ignored parameters to CUDA
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.half()
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
self.chunk_manager.register_tensor(tensor=p,
|
||||
group_type='fp16_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
self.chunk_manager.register_tensor(tensor=fp32_p,
|
||||
group_type='fp32_param',
|
||||
config_key=dp_world_size,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk_16.keep_gathered:
|
||||
self.grads_device[p] = get_current_device()
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.module.buffers():
|
||||
buffer.data = buffer.cuda()
|
||||
|
|
|
@ -49,6 +49,10 @@ class GeminiDDP(ZeroDDP):
|
|||
all parameters will be compacted into one small chunk.
|
||||
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
|
||||
"""
|
||||
# some ugly hotfix for the compatibility with Lightning
|
||||
if search_range_mb is None:
|
||||
search_range_mb = 32
|
||||
|
||||
chunk_manager = init_chunk_manager(model=module,
|
||||
init_device=device,
|
||||
hidden_dim=hidden_dim,
|
||||
|
|
|
@ -80,13 +80,11 @@ def get_static_torch_model(zero_ddp_model,
|
|||
from colossalai.nn.parallel import ZeroDDP
|
||||
assert isinstance(zero_ddp_model, ZeroDDP)
|
||||
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
|
||||
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
|
||||
colo_model = zero_ddp_model.module
|
||||
torch_model = _get_shallow_copy_model(colo_model)
|
||||
|
||||
if not only_rank_0 or dist.get_rank() == 0:
|
||||
# record the mapping relationship between colo parameters and torch parameters
|
||||
colo_to_torch = dict()
|
||||
for (name, colo_module), (_, torch_module) in \
|
||||
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
|
||||
# clean the parameter list of the new torch module
|
||||
|
@ -94,17 +92,10 @@ def get_static_torch_model(zero_ddp_model,
|
|||
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
|
||||
# get the full name of the parameter
|
||||
full_param_name = name + ('.' if name else '') + sufix_param_name
|
||||
|
||||
if full_param_name not in state_dict:
|
||||
# this means the parameter is shared by multiple modules
|
||||
# we should use colo_to_torch to get the torch parameter created before
|
||||
assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
|
||||
torch_param = colo_to_torch[param]
|
||||
else:
|
||||
# we meet the parameter the first time, just use the state dict to get the data
|
||||
assert full_param_name in state_dict, \
|
||||
f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
|
||||
state_param = state_dict[full_param_name]
|
||||
torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
|
||||
colo_to_torch[param] = torch_param
|
||||
|
||||
setattr(torch_module, sufix_param_name, torch_param)
|
||||
dist.barrier()
|
||||
|
|
Loading…
Reference in New Issue