mirror of https://github.com/hpcaitech/ColossalAI

[hotfix] fix lightning error (#2529)

parent b55deb0662
commit a4ed9125ac
@@ -5,6 +5,7 @@ from typing import Dict, Iterable, List, Optional, Set
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
@@ -218,11 +219,15 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(gemini_manager)
-        self.fp32_params: List[ColoTensor] = []
+        self.fp32_params: List[ColoTensor] = list()
+        self.fp16_params: List[ColoParameter] = list()
         self.overflow_counter = 0
-        self.grads_device: Dict[torch.Tensor, torch.device] = {}
+        self.grads_device: Dict[torch.Tensor, torch.device] = dict()
+        self.param2name: Dict[nn.Parameter, str] = dict()
+        self.name2param: Dict[str, nn.Parameter] = dict()

-        cpu_offload = self.gemini_manager.policy_name != 'cuda'
+        self._cast_buffers()
+        self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
             # build chunk in param runtime visited order.
@@ -234,50 +239,17 @@ class ZeroDDP(ColoDDP):
         for p in module.parameters():
             param_order.append(p)

-        ddp_pg = ColoProcessGroup()
-        for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
-
-            if strict_ddp_mode:
-                if not p.is_replicate():
-                    p.set_dist_spec(ReplicaSpec())
-                p.set_process_group(pg=ddp_pg)
-
-            if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
-                continue
-
-            fp32_data = p.data.float()
-            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
-            p.data = p.data.half()
-            dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.register_tensor(tensor=p,
-                                               group_type='fp16_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.chunk_manager.register_tensor(tensor=fp32_p,
-                                               group_type='fp32_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.gemini_manager.default_device
-
-        self.chunk_manager.close_all_groups()
-        self._cast_buffers()
-
-        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
-        for p, fp32_p in zip(params_list, self.fp32_params):
-            chunk_16 = self.chunk_manager.get_chunk(p)
-            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
-            chunk_32.init_pair(chunk_16)
-
-            # keep gathered chunks are in CUDA
-            if chunk_16.keep_gathered:
-                self.grads_device[p] = get_current_device()
-
-        self._logger = get_dist_logger()
+        self._init_chunks(param_order=param_order,
+                          strict_ddp_mode=strict_ddp_mode,
+                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
+                          pin_memory=pin_memory)
+
+        for name, param in module.named_parameters():
+            self.param2name[param] = name
+        for m_name, m_var in module.named_modules():
+            for p_name, p_var in m_var.named_parameters(recurse=False):
+                param_name = m_name + '.' + p_name if m_name else p_name
+                self.name2param[param_name] = p_var

     def _post_forward(self):
         """This function is only triggered for inference.
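The two loops added at the end of this hunk only record bookkeeping maps between parameters and their fully qualified names. Below is a minimal standalone sketch of the same idea on a plain `nn.Module` rather than ZeroDDP; the model and variable names are illustrative, not part of the commit:

```python
import torch.nn as nn

# Build param2name / name2param the same way the new __init__ does,
# but for an ordinary module.
module = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

param2name = {param: name for name, param in module.named_parameters()}

name2param = {}
for m_name, m_var in module.named_modules():
    for p_name, p_var in m_var.named_parameters(recurse=False):
        param_name = m_name + '.' + p_name if m_name else p_name
        name2param[param_name] = p_var

# Every tracked parameter can now be looked up in both directions.
assert set(param2name.values()) == set(name2param.keys())
```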
@@ -318,10 +290,23 @@ class ZeroDDP(ColoDDP):
                 continue
             p.grad = None

+    def _pre_bacward(self):
+        # set a visit label for all parameters
+        # the label is used to check whether the parameter is correctly reduced
+        for param in self.param2name:
+            if not is_ddp_ignored(param):
+                setattr(param, "_gemini_reduced", False)
+
     def _post_backward(self):
         if self.chunk_manager.accessed_mem != 0:
+            error_params = ["Reduction failed at followed parameters:"]
+            for param in self.param2name:
+                if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
+                    error_params.append(self.param2name[param])
+            error_str = "\n\t".join(error_params)
             raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
-                               "The most possible reason is that the model is not compatible with ZeroDDP.")
+                               "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+                               f"{error_str}")
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
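The `_pre_bacward` / `_post_backward` pair added here implements a simple visit-label check: every parameter is flagged before backward, the reduction path is expected to clear the flag, and any parameter still flagged afterwards is reported by name. A self-contained sketch of that pattern follows; `pre_backward`, `on_grad_reduced`, and `post_backward` are illustrative stand-ins, not ColossalAI APIs:

```python
import torch
import torch.nn as nn

def pre_backward(param2name):
    # flag every tracked parameter as "not reduced yet"
    for p in param2name:
        setattr(p, "_reduced", False)

def on_grad_reduced(p):
    # in ZeroDDP this would be flipped by the gradient-reduction hook
    p._reduced = True

def post_backward(param2name):
    missing = [name for p, name in param2name.items() if not getattr(p, "_reduced")]
    if missing:
        raise RuntimeError("gradient reduction did not run for:\n\t" + "\n\t".join(missing))

model = nn.Linear(2, 2)
param2name = {p: n for n, p in model.named_parameters()}

pre_backward(param2name)
model(torch.randn(1, 2)).sum().backward()
for p in param2name:          # pretend the reduction hook fired for every parameter
    on_grad_reduced(p)
post_backward(param2name)     # silent here; raises with parameter names otherwise
```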
@@ -329,6 +314,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
+        self._pre_bacward()
         with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
@@ -343,7 +329,9 @@ class ZeroDDP(ColoDDP):
         free_storage(empty_grad)
         with torch._C.DisableTorchFunction():
             chunk = self.chunk_manager.get_chunk(p)
-            assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+            if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
+                raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                                   "Some unsupported torch function is operated upon this parameter.")
             self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
             chunk.copy_tensor_to_chunk_slice(p, grad)
             reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -367,30 +355,7 @@ class ZeroDDP(ColoDDP):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device

-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        """
-        Args:
-            strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
-
-        Returns:
-            dict:
-                a dictionary containing a whole state of the module
-
-        Example:
-
-            >>> module.state_dict().keys()
-            ['bias', 'weight']
-        """
-        if strict:
-            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
-            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
-            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
-        return self._non_strict_state_dict(destination=destination,
-                                           prefix=prefix,
-                                           keep_vars=keep_vars,
-                                           only_rank_0=only_rank_0)
-
-    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
         """Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are included.
|
@ -461,19 +426,24 @@ class ZeroDDP(ColoDDP):
|
||||||
"""
|
"""
|
||||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||||
|
|
||||||
|
# get copies of fp32 parameters in CPU
|
||||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
||||||
ddp_param_list = []
|
# get the mapping between copies and fp16 parameters
|
||||||
for name, param in self.named_parameters():
|
p_mapping = dict()
|
||||||
if is_ddp_ignored(param):
|
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||||
# deal with ddp ignored parameters
|
name = self.param2name[p]
|
||||||
destination[prefix + name] = param if keep_vars else param.detach()
|
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||||
else:
|
record_parameter = param_to_save_data[fp32_p]
|
||||||
ddp_param_list.append((name, param))
|
p_mapping[p] = record_parameter
|
||||||
for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
|
for name, param in self.name2param.items():
|
||||||
if p is not None:
|
if param is not None:
|
||||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
if is_ddp_ignored(param):
|
||||||
record_parameter = param_to_save_data[fp32_p]
|
# deal with ddp ignored parameters
|
||||||
destination[prefix + name] = record_parameter
|
destination[prefix + name] = param if keep_vars else param.detach()
|
||||||
|
else:
|
||||||
|
destination[prefix + name] = p_mapping[param]
|
||||||
|
del p_mapping
|
||||||
|
del param_to_save_data
|
||||||
|
|
||||||
# save all buffers
|
# save all buffers
|
||||||
for name, buf in self.named_buffers():
|
for name, buf in self.named_buffers():
|
||||||
|
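The rewritten saving path keys the fp32 copies by the fp16 parameters they shadow and then fills `destination` by walking `name2param`. A toy sketch of that mapping idea, with fabricated fp32 "master copies" standing in for the chunk-managed ones:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).half()

# stand-in for the fp32 copies that ZeroDDP keeps per fp16 parameter
fp32_copies = {p: p.detach().float().clone() for p in model.parameters()}

destination = {}
for name, param in model.named_parameters():
    # write by name, but take the data from the fp32 master copy
    destination[name] = fp32_copies[param]

print({k: v.dtype for k, v in destination.items()})   # all torch.float32
```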
@@ -605,17 +575,15 @@ class ZeroDDP(ColoDDP):
         def load_fp32_parameter(chunk_slice, data):
             chunk_slice.copy_(data.flatten())

-        ddp_param_list = []
         for name, param in self.named_parameters():
             if is_ddp_ignored(param):
                 # deal with ddp ignored parameters
                 load(name, param, param.copy_)
-            else:
-                ddp_param_list.append((name, param))

         fp32_to_name = dict()
-        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             if p is not None:
+                name = self.param2name[p]
                 fp32_to_name[fp32_p] = name

         chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
@@ -662,6 +630,60 @@ class ZeroDDP(ColoDDP):
                 if input_name not in local_state:
                     unexpected_keys.append(key)

+    def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
+        ddp_pg = ColoProcessGroup()
+        for p in param_order.generate():
+            assert isinstance(p, ColoParameter)
+
+            # gather sharded parameters in the strict ddp mode
+            if strict_ddp_mode:
+                if not p.is_replicate():
+                    p.set_dist_spec(ReplicaSpec())
+                p.set_process_group(pg=ddp_pg)
+
+            # ignore the parameters with no gradient
+            if not p.requires_grad:
+                self.set_params_to_ignore([p])
+
+            # move ignored parameters to CUDA
+            if is_ddp_ignored(p):
+                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                continue
+
+            # create a fp32 parameter
+            fp32_data = p.data.float()
+            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+            # create a fp16 parameter
+            p.data = p.data.half()
+
+            # register the fp16 parameter and fp32 parameter in the chunk manager
+            dp_world_size = p.process_group.dp_world_size()
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+
+            self.fp16_params.append(p)
+            self.fp32_params.append(fp32_p)
+            self.grads_device[p] = self.gemini_manager.default_device
+
+        self.chunk_manager.close_all_groups()
+
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            chunk_16 = self.chunk_manager.get_chunk(p)
+            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
+            chunk_32.init_pair(chunk_16)
+
+            # keep gathered chunks are in CUDA
+            if chunk_16.keep_gathered:
+                self.grads_device[p] = get_current_device()
+
     def _cast_buffers(self):
         for buffer in self.module.buffers():
             buffer.data = buffer.cuda()
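Besides moving the chunk construction into `_init_chunks`, the new method adds one behavioural change: parameters with `requires_grad=False` are routed to the ignored set before any chunk is built. A minimal sketch of that filtering step in plain PyTorch, without a chunk manager:

```python
import torch.nn as nn

model = nn.Linear(8, 8)
model.bias.requires_grad_(False)     # e.g. a frozen parameter

ignored, managed = [], []
for p in model.parameters():
    if not p.requires_grad:
        ignored.append(p)            # ZeroDDP would call set_params_to_ignore([p]) here
        continue
    managed.append(p)                # only these would be registered in chunks

print(len(ignored), len(managed))    # 1 frozen parameter, 1 trainable parameter
```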
@@ -49,6 +49,10 @@ class GeminiDDP(ZeroDDP):
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
+        # some ugly hotfix for the compatibility with Lightning
+        if search_range_mb is None:
+            search_range_mb = 32
+
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
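These four added lines are the Lightning fix itself: if an integration layer passes `search_range_mb=None` down to GeminiDDP, fall back to 32 instead of forwarding `None` to `init_chunk_manager`. A standalone sketch of the same defensive default; the helper name is made up for illustration:

```python
from typing import Optional

def resolve_search_range_mb(search_range_mb: Optional[int]) -> int:
    # mirror of the hotfix: tolerate None coming from a wrapper such as Lightning
    if search_range_mb is None:
        search_range_mb = 32
    return search_range_mb

assert resolve_search_range_mb(None) == 32
assert resolve_search_range_mb(64) == 64
```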
@@ -80,13 +80,11 @@ def get_static_torch_model(zero_ddp_model,
     from colossalai.nn.parallel import ZeroDDP
     assert isinstance(zero_ddp_model, ZeroDDP)

-    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
     colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)

     if not only_rank_0 or dist.get_rank() == 0:
-        # record the mapping relationship between colo parameters and torch parameters
-        colo_to_torch = dict()
         for (name, colo_module), (_, torch_module) in \
                 zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
             # clean the parameter list of the new torch module
@@ -94,17 +92,10 @@ def get_static_torch_model(zero_ddp_model,
             for sufix_param_name, param in colo_module.named_parameters(recurse=False):
                 # get the full name of the parameter
                 full_param_name = name + ('.' if name else '') + sufix_param_name
-                if full_param_name not in state_dict:
-                    # this means the parameter is shared by multiple modules
-                    # we should use colo_to_torch to get the torch parameter created before
-                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
-                    torch_param = colo_to_torch[param]
-                else:
-                    # we meet the parameter the first time, just use the state dict to get the data
-                    state_param = state_dict[full_param_name]
-                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
-                    colo_to_torch[param] = torch_param
-
+                assert full_param_name in state_dict, \
+                    f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
+                state_param = state_dict[full_param_name]
+                torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))

                 setattr(torch_module, sufix_param_name, torch_param)
     dist.barrier()
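The removed `colo_to_torch` branch existed to handle parameters shared by multiple modules (tied weights), which appear under more than one name when modules are walked individually; after this change the code simply asserts that each full name is present in the state dict. A small illustration of why a tied weight produces two names under the per-module traversal used here:

```python
import torch.nn as nn

embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embed.weight           # weight tying: one tensor, two owners

model = nn.ModuleDict({'embed': embed, 'head': head})

names = []
for m_name, m in model.named_modules():
    for p_name, p in m.named_parameters(recurse=False):
        names.append(m_name + '.' + p_name if m_name else p_name)

print(names)   # ['embed.weight', 'head.weight'] -> both refer to the same tensor
```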