mirror of https://github.com/hpcaitech/ColossalAI
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk * [gemini] support no master weight for optim * [gemini] support no master weight for gemini ddp * [test] update gemini tests * [test] update gemini tests * [plugin] update gemini plugin * [test] fix gemini checkpointio test * [test] fix gemini checkpoint iopull/4864/head
parent
c1fab951e7
commit
df63564184
|
@ -97,7 +97,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
|
@ -257,6 +257,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
|
||||
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
|
||||
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
|
||||
master_weights (bool, optional): master weights. Defaults to True.
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
|
||||
|
@ -296,6 +297,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
|
||||
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
|
||||
precision: str = "fp16",
|
||||
master_weights: bool = True,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
|
@ -334,6 +336,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
min_chunk_size_m=min_chunk_size_m,
|
||||
memstats=memstats,
|
||||
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
|
||||
master_weights=master_weights,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||
|
|
|
@ -132,9 +132,6 @@ class CPUAdam(NVMeOptimizer):
|
|||
target_device = p.device
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
|
||||
# gradient momentums
|
||||
state["exp_avg"] = torch.zeros_like(p, device=target_device)
|
||||
# gradient variances
|
||||
|
@ -149,7 +146,8 @@ class CPUAdam(NVMeOptimizer):
|
|||
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
self._pre_update(p, "exp_avg", "exp_avg_sq")
|
||||
if p.grad.dtype is torch.bfloat16:
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
|
||||
# cpu adam kernel does not support bf16 now
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
|
|
@ -108,9 +108,6 @@ class HybridAdam(CPUAdam):
|
|||
target_device = p.device
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
|
||||
# gradient momentums
|
||||
state["exp_avg"] = torch.zeros_like(p, device=target_device)
|
||||
# gradient variances
|
||||
|
@ -125,7 +122,8 @@ class HybridAdam(CPUAdam):
|
|||
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
|
||||
self._pre_update(p, "exp_avg", "exp_avg_sq")
|
||||
if p.grad.dtype is torch.bfloat16:
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
|
||||
# cpu adam kernel does not support bf16 now
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
|
|
@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
|
|||
assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"
|
||||
|
||||
|
||||
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
|
||||
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
|
||||
assert len(list(d1.keys())) == len(
|
||||
list(d2.keys())
|
||||
), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
|
||||
|
@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
|
|||
if not ignore_device:
|
||||
v1_i = v1_i.to("cpu")
|
||||
v2_i = v2_i.to("cpu")
|
||||
if ignore_dtype:
|
||||
v1_i = v1_i.to(v2_i.dtype)
|
||||
assert_close_loose(v1_i, v2_i)
|
||||
elif isinstance(v1_i, dict):
|
||||
assert isinstance(v2_i, dict)
|
||||
|
@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
|
|||
if not ignore_device:
|
||||
v1 = v1.to("cpu")
|
||||
v2 = v2.to("cpu")
|
||||
if ignore_dtype:
|
||||
v1 = v1.to(v2.dtype)
|
||||
assert_close_loose(v1, v2)
|
||||
else:
|
||||
assert v1 == v2, f"{v1} not equals to {v2}"
|
||||
|
|
|
@ -160,6 +160,8 @@ class Chunk:
|
|||
self.l2_norm_flag = False
|
||||
self.l2_norm = None
|
||||
|
||||
self.grad_chunk = None
|
||||
|
||||
@property
|
||||
def memory_usage(self) -> Dict[str, int]:
|
||||
cuda_memory = 0
|
||||
|
@ -414,7 +416,9 @@ class Chunk:
|
|||
return
|
||||
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
def copy_tensor_to_chunk_slice(
|
||||
self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
||||
|
||||
|
@ -427,7 +431,8 @@ class Chunk:
|
|||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
if update_ptr:
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload."""
|
||||
|
@ -577,3 +582,46 @@ class Chunk:
|
|||
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
|
||||
|
||||
return "".join(output)
|
||||
|
||||
def init_grad_chunk(self) -> "Chunk":
|
||||
"""Init grad chunk. This should be called in grad handler.
|
||||
|
||||
Returns:
|
||||
Chunk: Grad chunk
|
||||
"""
|
||||
if self.grad_chunk is None:
|
||||
# grad chunk is not initialized
|
||||
grad_chunk = Chunk(
|
||||
chunk_size=self.chunk_size,
|
||||
process_group=self.torch_pg,
|
||||
dtype=self.dtype,
|
||||
keep_gathered=self.keep_gathered,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
grad_chunk.num_tensors = self.num_tensors
|
||||
grad_chunk.utilized_size = self.utilized_size
|
||||
grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors
|
||||
for tensor, state in self.tensors_info.items():
|
||||
grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end)
|
||||
|
||||
grad_chunk.valid_end = self.valid_end
|
||||
|
||||
if grad_chunk.chunk_temp.device.type == "cpu":
|
||||
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
|
||||
else:
|
||||
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
|
||||
grad_chunk.chunk_temp = None
|
||||
|
||||
if grad_chunk.pin_memory:
|
||||
grad_chunk.cpu_shard = torch.empty(
|
||||
grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory
|
||||
)
|
||||
|
||||
self.grad_chunk = grad_chunk
|
||||
else:
|
||||
# grad chunk is initialized, just reallocate cuda global chunk
|
||||
self.grad_chunk.cuda_shard = None
|
||||
self.grad_chunk.is_gathered = True
|
||||
alloc_storage(self.grad_chunk.cuda_global_chunk)
|
||||
|
||||
return self.grad_chunk
|
||||
|
|
|
@ -245,3 +245,13 @@ class ChunkManager:
|
|||
chunk.release_chunk()
|
||||
self.accessed_chunks.remove(chunk)
|
||||
self.accessed_mem -= chunk.chunk_mem
|
||||
|
||||
def init_grad_chunk(self, chunk: Chunk) -> Chunk:
|
||||
if chunk.grad_chunk is not None:
|
||||
self.__sub_memory_usage(chunk.grad_chunk.memory_usage)
|
||||
grad_chunk = chunk.init_grad_chunk()
|
||||
self.__add_memory_usage(grad_chunk.memory_usage)
|
||||
if grad_chunk not in self.accessed_chunks:
|
||||
self.accessed_chunks.add(grad_chunk)
|
||||
self.accessed_mem += grad_chunk.chunk_mem
|
||||
return grad_chunk
|
||||
|
|
|
@ -74,6 +74,7 @@ class GeminiDDP(ModelWrapper):
|
|||
mixed_precision: torch.dtype = torch.float16,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
memstats: Optional[MemStats] = None, # genimi memory stats
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
|
@ -115,6 +116,9 @@ class GeminiDDP(ModelWrapper):
|
|||
self.mixed_precision = mixed_precision
|
||||
self.dp_process_group = process_group or _get_default_group()
|
||||
|
||||
self.reuse_fp16_chunk = master_weights
|
||||
self.master_weights = master_weights
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
|
@ -321,20 +325,37 @@ class GeminiDDP(ModelWrapper):
|
|||
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
|
||||
"Some unsupported torch function is operated upon this parameter."
|
||||
)
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
grad_chunk = chunk
|
||||
if not self.reuse_fp16_chunk:
|
||||
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
|
||||
# hold -> compute -> hold after bwd
|
||||
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
|
||||
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
|
||||
# fp16 param chunk: hold after bwd -> ready for reduce -> hold
|
||||
chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||
chunk.tensor_trans_state(p, TensorState.HOLD)
|
||||
|
||||
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
|
||||
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
if not self.reuse_fp16_chunk:
|
||||
if chunk.keep_gathered:
|
||||
self.chunk_manager.fake_release_chunk(chunk)
|
||||
else:
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if grad_chunk.is_gathered:
|
||||
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
grad_chunk.cuda_shard.div_(chunk.pg_size)
|
||||
# check overflow elements
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping
|
||||
self.overflow_counter += grad_chunk.has_inf_or_nan
|
||||
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
|
||||
if chunk.l2_norm_flag:
|
||||
chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
grad_chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
|
||||
if not self.master_weights:
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
|
@ -344,9 +365,7 @@ class GeminiDDP(ModelWrapper):
|
|||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
def state_dict(
|
||||
self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
|
||||
):
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True):
|
||||
"""Returns a dictionary containing a whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
|
@ -365,7 +384,7 @@ class GeminiDDP(ModelWrapper):
|
|||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
|
@ -373,7 +392,7 @@ class GeminiDDP(ModelWrapper):
|
|||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
|
||||
def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get gathered chunk content.
|
||||
|
||||
|
@ -386,9 +405,8 @@ class GeminiDDP(ModelWrapper):
|
|||
"""
|
||||
# save parameters
|
||||
chunk_to_save_data = dict()
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
if torch.is_floating_point(temp_chunk):
|
||||
temp_chunk = temp_chunk.to(dtype)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
|
@ -401,9 +419,7 @@ class GeminiDDP(ModelWrapper):
|
|||
del temp_chunk
|
||||
return chunk_to_save_data
|
||||
|
||||
def _get_param_to_save_data(
|
||||
self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
|
||||
) -> Dict:
|
||||
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
|
||||
"""
|
||||
get param content from chunks.
|
||||
|
||||
|
@ -418,10 +434,10 @@ class GeminiDDP(ModelWrapper):
|
|||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||
return param_to_save_data
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
@ -438,14 +454,18 @@ class GeminiDDP(ModelWrapper):
|
|||
|
||||
# get copies of fp32 parameters in CPU
|
||||
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
|
||||
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
|
||||
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
||||
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
|
||||
# get the mapping between copies and fp16 parameters
|
||||
p_mapping = dict()
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
if self.reuse_fp16_chunk:
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
name = self.param2name[p]
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
p_mapping[p] = record_parameter
|
||||
else:
|
||||
p_mapping = param_to_save_data
|
||||
for name, param in self.name2param.items():
|
||||
if param is not None:
|
||||
if is_ddp_ignored(param):
|
||||
|
@ -593,7 +613,7 @@ class GeminiDDP(ModelWrapper):
|
|||
elif strict:
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
def load_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
|
@ -607,14 +627,15 @@ class GeminiDDP(ModelWrapper):
|
|||
name = self.param2name[p]
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
|
||||
chunk_list = self.chunk_manager.get_chunks(params_to_load)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
load(parameter_name, tensor, partial(load_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.cuda_global_chunk.copy_(temp_chunk)
|
||||
|
@ -624,11 +645,11 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
if self.reuse_fp16_chunk:
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
@ -668,12 +689,9 @@ class GeminiDDP(ModelWrapper):
|
|||
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_p = p.data.float()
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
# register the fp16 parameter
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=p,
|
||||
group_type="fp16_param",
|
||||
|
@ -682,22 +700,27 @@ class GeminiDDP(ModelWrapper):
|
|||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=fp32_p,
|
||||
group_type="fp32_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
self.fp16_params.append(p)
|
||||
self.fp32_params.append(fp32_p)
|
||||
|
||||
if self.master_weights:
|
||||
# create a fp32 parameter
|
||||
fp32_p = p.data.float()
|
||||
self.chunk_manager.register_tensor(
|
||||
tensor=fp32_p,
|
||||
group_type="fp32_param",
|
||||
config_key=dp_world_size,
|
||||
process_group=self.dp_process_group,
|
||||
cpu_offload=cpu_offload,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.fp32_params.append(fp32_p)
|
||||
|
||||
self.chunk_manager.close_all_groups()
|
||||
|
||||
self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
|
||||
|
||||
# move master weights to corresponding device and setup paired chunks
|
||||
# if no master weights, fp32_params should be empty and this loop will be skipped
|
||||
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
|
@ -734,7 +757,6 @@ class GeminiDDP(ModelWrapper):
|
|||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
|
@ -769,11 +791,11 @@ class GeminiDDP(ModelWrapper):
|
|||
gathered_param = param if keep_vars else param.detach()
|
||||
else:
|
||||
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
|
||||
fp32_param = fp16_to_fp32[param]
|
||||
if fp32_param not in gathered_param_buffer:
|
||||
chunk = self.chunk_manager.get_chunk(fp32_param)
|
||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||
gathered_param = gathered_param_buffer.pop(fp32_param)
|
||||
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
|
||||
if param_to_save not in gathered_param_buffer:
|
||||
chunk = self.chunk_manager.get_chunk(param_to_save)
|
||||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
|
||||
gathered_param = gathered_param_buffer.pop(param_to_save)
|
||||
|
||||
block, block_size = sharder.append_param(prefix + name, gathered_param)
|
||||
if block is not None:
|
||||
|
|
|
@ -105,7 +105,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.param_to_chunk16: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
|
@ -130,7 +130,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
else:
|
||||
ddp_param_list.append(param)
|
||||
|
||||
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
|
||||
for p in ddp_param_list:
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
if chunk_16 not in self.chunk16_set:
|
||||
chunk_16.l2_norm_flag = self.clipping_flag
|
||||
|
@ -174,13 +174,15 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = self.param_to_chunk16[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
fake_param.data = chunk16.payload[begin:end]
|
||||
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
|
||||
fake_param.data = grad_chunk16.payload[begin:end]
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = chunk32.payload[begin:end]
|
||||
|
||||
to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16
|
||||
fake_param.data = to_update_chunk.payload[begin:end]
|
||||
|
||||
def _update_fp16_params(self):
|
||||
none_tensor = torch.empty([0])
|
||||
|
@ -194,23 +196,25 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
|
||||
def _clear_global_norm(self) -> None:
|
||||
for c16 in self.chunk16_set:
|
||||
c16.l2_norm = None
|
||||
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
|
||||
grad_chunk.l2_norm = None
|
||||
|
||||
def _calc_global_norm(self) -> float:
|
||||
norm_sqr: float = 0.0
|
||||
group_to_norm = dict()
|
||||
for c16 in self.chunk16_set:
|
||||
assert c16.l2_norm is not None
|
||||
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
|
||||
assert grad_chunk.l2_norm is not None
|
||||
|
||||
if c16.is_gathered:
|
||||
norm_sqr += c16.l2_norm
|
||||
if grad_chunk.is_gathered:
|
||||
norm_sqr += grad_chunk.l2_norm
|
||||
else:
|
||||
# this chunk is sharded, use communication to collect total norm
|
||||
if c16.torch_pg not in group_to_norm:
|
||||
group_to_norm[c16.torch_pg] = 0.0
|
||||
group_to_norm[c16.torch_pg] += c16.l2_norm
|
||||
if grad_chunk.torch_pg not in group_to_norm:
|
||||
group_to_norm[grad_chunk.torch_pg] = 0.0
|
||||
group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm
|
||||
|
||||
c16.l2_norm = None # clear l2 norm
|
||||
grad_chunk.l2_norm = None # clear l2 norm
|
||||
|
||||
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
for group, part_norm in group_to_norm.items():
|
||||
|
@ -237,7 +241,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
if self.module.master_weights:
|
||||
self._maybe_move_fp32_params()
|
||||
self._set_grad_ptr()
|
||||
|
||||
if self.mix_precision_mixin.should_skip_step():
|
||||
|
@ -245,7 +250,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
self._logger.info(f"Found overflow. Skip step")
|
||||
self._clear_global_norm() # clear recorded norm
|
||||
self.zero_grad() # reset all gradients
|
||||
self._update_fp16_params()
|
||||
if self.module.reuse_fp16_chunk:
|
||||
self._update_fp16_params()
|
||||
return
|
||||
|
||||
# get combined scale. combined scale = loss scale * clipping norm
|
||||
|
@ -255,7 +261,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
|
||||
self._register_states()
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
if self.module.master_weights:
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
|
@ -282,8 +289,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
chunk16 = self.param_to_chunk16[fake_param]
|
||||
chunk32 = chunk16.paired_chunk
|
||||
|
||||
if chunk32.device_type == "cuda":
|
||||
continue
|
||||
|
@ -297,7 +304,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group["params"]:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = self.param_to_chunk16[fake_param]
|
||||
chunk32 = chunk16.paired_chunk
|
||||
if chunk32.device_type == "cuda":
|
||||
state = self.optim.state[fake_param]
|
||||
for k, v in state.items():
|
||||
|
@ -341,7 +349,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
continue
|
||||
grad_device = self.module.grads_device[param]
|
||||
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
|
||||
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
||||
self.param_to_chunk16[fake_param] = chunk16
|
||||
self.param_to_range[fake_param] = range_pair
|
||||
self.id_to_fake_params[param_id] = fake_param
|
||||
fake_params_list.append(fake_param)
|
||||
|
@ -366,7 +374,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
if param_id not in self.id_to_fake_params:
|
||||
return -1, -1, -1
|
||||
fake_param = self.id_to_fake_params[param_id]
|
||||
chunk = self.param_to_chunk32[fake_param].paired_chunk
|
||||
chunk = self.param_to_chunk16[fake_param]
|
||||
param = self.id_to_real_params[param_id]
|
||||
param_info = chunk.tensors_info[param]
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
|
|||
from .chunk import Chunk
|
||||
|
||||
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
|
||||
if chunk.is_gathered:
|
||||
return chunk.cuda_global_chunk
|
||||
|
||||
|
@ -20,7 +20,9 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
|||
else:
|
||||
shard_temp = chunk.cpu_shard.to(get_current_device())
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
|
||||
shard_temp = shard_temp.to(dtype)
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
|
||||
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
|
||||
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
|
||||
|
||||
|
|
|
@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
|
|||
dist.barrier()
|
||||
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
|
||||
check_state_dict_equal(
|
||||
bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False
|
||||
)
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
|
@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
|
|||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
|
||||
check_state_dict_equal(
|
||||
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
|
||||
)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(
|
||||
|
@ -136,7 +136,7 @@ def exam_lazy_from_pretrained():
|
|||
booster.save_model(model, save_path, shard=False)
|
||||
dist.barrier()
|
||||
state_dict = torch.load(save_path, map_location="cpu")
|
||||
check_state_dict_equal(state_dict, orig_state_dict, False)
|
||||
check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
|
|
@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
|
|||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
|
||||
model.state_dict(only_rank_0=False, prefix="module.module."),
|
||||
new_model.state_dict(),
|
||||
False,
|
||||
ignore_dtype=True,
|
||||
)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
|
@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
|
|||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
|
||||
new_model.state_dict(only_rank_0=False, prefix="module.module."),
|
||||
model.state_dict(),
|
||||
False,
|
||||
ignore_dtype=True,
|
||||
)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
|
|
|
@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
|||
chunk_manager = model.chunk_manager
|
||||
param_list = [p for p in model.parameters()]
|
||||
chunk_list = chunk_manager.get_chunks(param_list)
|
||||
if not model.reuse_fp16_chunk:
|
||||
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
|
||||
for chunk in chunk_list:
|
||||
chunk_manager.access_chunk(chunk)
|
||||
|
||||
|
@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
|||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gather", [False, True])
|
||||
@parameterize("model_name", ["gpt2", "bert", "albert"])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
@parameterize("use_grad_checkpoint", [False, True])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_gpt_fwd_bwd(
|
||||
placement_config,
|
||||
keep_gather,
|
||||
model_name: str,
|
||||
use_grad_checkpoint: bool = False,
|
||||
master_weights: bool = True,
|
||||
):
|
||||
init_device = get_current_device()
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
|
@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd(
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = keep_gather
|
||||
model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
|
||||
model = GeminiDDP(
|
||||
model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
|
||||
)
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
|
||||
|
||||
rank = dist.get_rank()
|
||||
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1)
|
||||
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[rank])
|
||||
|
@ -106,4 +112,4 @@ def test_gpt(world_size):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt(4)
|
||||
test_gpt(1)
|
||||
|
|
|
@ -78,7 +78,11 @@ def exam_grad_clipping(placement_config, model_name: str):
|
|||
init_device = None
|
||||
|
||||
model = GeminiDDP(
|
||||
model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config
|
||||
model,
|
||||
chunk_config_dict=config_dict,
|
||||
chunk_init_device=init_device,
|
||||
pin_memory=True,
|
||||
**placement_config,
|
||||
)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
|
|
|
@ -44,7 +44,7 @@ BF16_IGNORED_KEYS = [
|
|||
|
||||
|
||||
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
||||
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
|
|
|
@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
|
|||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gathered", [True, False])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
def exam_state_dict(placement_config, keep_gathered, model_name: str):
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
|
|||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
|
|||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gathered", [True, False])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
|
|||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
|
||||
|
||||
torch_dict = torch_model.state_dict()
|
||||
model.load_state_dict(torch_dict, strict=False)
|
||||
|
@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
|
|||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
def exam_state_dict_shard(placement_config, model_name: str):
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
|
@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str):
|
|||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
model = GeminiDDP(model, config_dict, **placement_config)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
|
Loading…
Reference in New Issue