mirror of https://github.com/hpcaitech/ColossalAI
[gemini] state dict supports fp16 (#3590)
* [gemini] save state dict support fp16 * [gemini] save state dict shard support fp16 * [gemini] fix state dict * [gemini] fix state dictpull/3603/head
parent
d544ed4345
commit
12eff9eb4c
|
@ -202,7 +202,12 @@ class ZeroDDP(ColoDDP):
|
||||||
for tensor in chunk.get_tensors():
|
for tensor in chunk.get_tensors():
|
||||||
self.grads_device[tensor] = device
|
self.grads_device[tensor] = device
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
|
def state_dict(self,
|
||||||
|
destination=None,
|
||||||
|
prefix='',
|
||||||
|
keep_vars=False,
|
||||||
|
only_rank_0: bool = True,
|
||||||
|
dtype: torch.dtype = torch.float16):
|
||||||
"""Returns a dictionary containing a whole state of the module.
|
"""Returns a dictionary containing a whole state of the module.
|
||||||
|
|
||||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||||
|
@ -221,7 +226,7 @@ class ZeroDDP(ColoDDP):
|
||||||
destination = OrderedDict()
|
destination = OrderedDict()
|
||||||
destination._metadata = OrderedDict()
|
destination._metadata = OrderedDict()
|
||||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
|
||||||
|
|
||||||
for hook in self._state_dict_hooks.values():
|
for hook in self._state_dict_hooks.values():
|
||||||
hook_result = hook(self, destination, prefix, local_metadata)
|
hook_result = hook(self, destination, prefix, local_metadata)
|
||||||
|
@ -229,7 +234,7 @@ class ZeroDDP(ColoDDP):
|
||||||
destination = hook_result
|
destination = hook_result
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
|
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
|
||||||
"""
|
"""
|
||||||
get gathered chunk content.
|
get gathered chunk content.
|
||||||
|
|
||||||
|
@ -243,6 +248,8 @@ class ZeroDDP(ColoDDP):
|
||||||
# save parameters
|
# save parameters
|
||||||
chunk_to_save_data = dict()
|
chunk_to_save_data = dict()
|
||||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||||
|
if torch.is_floating_point(temp_chunk):
|
||||||
|
temp_chunk = temp_chunk.to(dtype)
|
||||||
for tensor, tensor_info in chunk.tensors_info.items():
|
for tensor, tensor_info in chunk.tensors_info.items():
|
||||||
record_tensor = torch.empty([0])
|
record_tensor = torch.empty([0])
|
||||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||||
|
@ -255,7 +262,8 @@ class ZeroDDP(ColoDDP):
|
||||||
del temp_chunk
|
del temp_chunk
|
||||||
return chunk_to_save_data
|
return chunk_to_save_data
|
||||||
|
|
||||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
|
||||||
|
dtype: torch.dtype) -> Dict:
|
||||||
"""
|
"""
|
||||||
get param content from chunks.
|
get param content from chunks.
|
||||||
|
|
||||||
|
@ -270,10 +278,10 @@ class ZeroDDP(ColoDDP):
|
||||||
param_to_save_data = dict()
|
param_to_save_data = dict()
|
||||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||||
for chunk in chunk_list:
|
for chunk in chunk_list:
|
||||||
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||||
return param_to_save_data
|
return param_to_save_data
|
||||||
|
|
||||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
|
||||||
r"""Saves module state to `destination` dictionary, containing a state
|
r"""Saves module state to `destination` dictionary, containing a state
|
||||||
of the module, but not its descendants. This is called on every
|
of the module, but not its descendants. This is called on every
|
||||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||||
|
@ -289,7 +297,8 @@ class ZeroDDP(ColoDDP):
|
||||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||||
|
|
||||||
# get copies of fp32 parameters in CPU
|
# get copies of fp32 parameters in CPU
|
||||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
|
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
|
||||||
|
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
|
||||||
# get the mapping between copies and fp16 parameters
|
# get the mapping between copies and fp16 parameters
|
||||||
p_mapping = dict()
|
p_mapping = dict()
|
||||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||||
|
@ -574,7 +583,8 @@ class ZeroDDP(ColoDDP):
|
||||||
prefix: str = '',
|
prefix: str = '',
|
||||||
keep_vars: bool = False,
|
keep_vars: bool = False,
|
||||||
max_shard_size: int = 1024,
|
max_shard_size: int = 1024,
|
||||||
only_rank_0: bool = True) -> Iterator[OrderedDict]:
|
only_rank_0: bool = True,
|
||||||
|
dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
|
||||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||||
|
|
||||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||||
|
@ -607,10 +617,11 @@ class ZeroDDP(ColoDDP):
|
||||||
# deal with ddp ignored parameters
|
# deal with ddp ignored parameters
|
||||||
gathered_param = param if keep_vars else param.detach()
|
gathered_param = param if keep_vars else param.detach()
|
||||||
else:
|
else:
|
||||||
|
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
|
||||||
fp32_param = fp16_to_fp32[param]
|
fp32_param = fp16_to_fp32[param]
|
||||||
if fp32_param not in gathered_param_buffer:
|
if fp32_param not in gathered_param_buffer:
|
||||||
chunk = self.chunk_manager.get_chunk(fp32_param)
|
chunk = self.chunk_manager.get_chunk(fp32_param)
|
||||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||||
gathered_param = gathered_param_buffer.pop(fp32_param)
|
gathered_param = gathered_param_buffer.pop(fp32_param)
|
||||||
|
|
||||||
block = sharder.append(prefix + name, gathered_param)
|
block = sharder.append(prefix + name, gathered_param)
|
||||||
|
|
Loading…
Reference in New Issue