mirror of https://github.com/hpcaitech/ColossalAI
[feature] support no master weights option for low level zero plugin (#4816)
* [feature] support no master weights for low level zero plugin
* [feature] support no master weights for low level zero plugin, remove data copy when no master weights
* remove data copy and typecasting when no master weights
* not load weights to cpu when using no master weights
* fix grad: use fp16 grad when no master weights
* only do not update working param when no master weights
* fix: only do not update working param when no master weights
* fix: passing params in dict format in hybrid plugin
* fix: remove extra params (tp_process_group) in hybrid_parallel_plugin

branch: pull/4314/head^2
parent 77a9328304
commit a0684e7bd6
@@ -464,23 +464,23 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
-            optimizer,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            clip_grad_norm,
-            verbose,
-            reduce_bucket_size,
-            communication_dtype,
-            overlap_communication,
-            partition_grad,
-            cpu_offload,
-            dp_process_group,
-            forced_dtype,
+            optimizer=optimizer,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            clip_grad_norm=clip_grad_norm,
+            verbose=verbose,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            partition_grad=partition_grad,
+            cpu_offload=cpu_offload,
+            dp_process_group=dp_process_group,
+            forced_dtype=forced_dtype,
         )

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
@@ -262,6 +262,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         cpu_offload: bool = False,
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()

@@ -272,18 +273,19 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.precision = precision
         self.zero_optim_kwargs = dict(
             initial_scale=initial_scale,
-            min_scale=min_scale,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
+            min_scale=min_scale,
             max_scale=max_scale,
             clip_grad_norm=max_norm,
             reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
-            cpu_offload=cpu_offload,
             partition_grad=(stage == 2),
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
         )
         self.verbose = verbose
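For orientation, a minimal usage sketch of the new option (not part of this diff): it assumes the usual colossalai Booster workflow and a torchrun-style distributed launch, and the tiny model/optimizer below are placeholders.

# Minimal sketch (assumed API usage, not code from this commit):
# run ZeRO while skipping the fp32 master-weight copy.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})  # assumes the process was started via torchrun

model = torch.nn.Linear(1024, 1024)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# master_weights=False keeps the sharded params in the working dtype (fp16 here),
# so no fp32 master copy is created and the extra grad/param casts are skipped.
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)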
@@ -75,6 +75,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
+        master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm

+        # master weights copy
+        self._master_weights = master_weights
+
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -135,7 +139,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             self._working_param_groups[group_id] = group_params

             master_param_current_rank = self._create_master_param_current_rank(group_params)
-
             self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

             # need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                    # reset working params' ptr when no master weights
+                    if self._master_weights == False:
+                        param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
                 splited_params = padding_param.split(padding_param.numel() // self._world_size)

-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                # use fp32 when master_weights is True
+                if self._master_weights is True:
+                    splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                else:
+                    splited_param_current_rank = splited_params[self._local_rank]
                 params_current_rank.append(splited_param_current_rank)
                 self._param_store.link_master_and_working_param(splited_param_current_rank, param)
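The hunk above is the core of the feature. Below is a simplified, self-contained sketch of the same padding/split/master-copy idea (illustration only, not the library code; the world size, rank, and tensor shape are made up):

# Simplified illustration (not library code): pad the flat param, split it across
# ranks, then either keep an fp32 master copy of the local shard (master_weights=True)
# or train directly on the fp16 shard (master_weights=False).
import torch

def shard_param(param: torch.nn.Parameter, world_size: int, local_rank: int, master_weights: bool) -> torch.Tensor:
    flat = param.data.view(-1)
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
        if not master_weights:
            # reuse the padded buffer as the working param, avoiding an extra copy
            param.data = flat[: param.numel()].view(param.shape)
    shard = flat.split(flat.numel() // world_size)[local_rank]
    # keep a detached fp32 copy only when master weights are enabled
    return shard.detach().float() if master_weights else shard

p = torch.nn.Parameter(torch.randn(10).half())
print(shard_param(p, world_size=4, local_rank=0, master_weights=True).dtype)   # torch.float32
print(shard_param(p, world_size=4, local_rank=0, master_weights=False).dtype)  # torch.float16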
@@ -402,9 +412,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # and should not be updated
             real_working_params = dict()
             real_master_params = dict()
-
             grad_index = 0 if self._partition_grads else self._local_rank
-
             for group_id in range(self.num_param_groups):
                 master_params = self._master_param_groups_of_current_rank[group_id]
                 real_working_params[group_id] = []
@@ -417,7 +425,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                    # no need to copy fp32 grad if master_weights is False
+                    grad = (
+                        grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        if self._master_weights
+                        else grads[grad_index]
+                    )
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
@@ -445,17 +458,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
-        dtype = real_working_params[0][0].dtype
+        # dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
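For reference, a stripped-down sketch of what the updated working-param sync at the end of step() amounts to (illustration only, not the library code; it assumes torch.distributed is already initialized and that each rank holds its updated shard):

# Simplified illustration (not library code) of the all-gather sync above.
# Gathering is done in the working dtype (self._dtype in the diff), which covers
# both the fp32-master case and the no-master-weights case.
import torch
import torch.distributed as dist

def sync_working_param(working_param: torch.Tensor, local_shard: torch.Tensor, world_size: int, group=None) -> None:
    gathered = [torch.empty_like(local_shard, dtype=working_param.dtype) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard.to(working_param.dtype), group=group)
    flat = torch.cat(gathered)
    # drop the padding that was added when the param was originally split
    working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))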