@@ -3,7 +3,6 @@
 
 import math
 from functools import partial
-from itertools import product
 
 import torch
 import torch.distributed as dist
@@ -11,7 +10,6 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import is_moe_param
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -117,16 +115,15 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
-        self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
-        self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
-        self._broadcast_parallel_mode = ParallelMode.ZERO1
+        self._zero_local_rank = []
+        self._zero_world_size = []
+        self._broadcast_parallel_mode = []
 
         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
-        self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
+        self._bucket_store = []
         self._bucket_in_progress = []
 
         # fp16 and fp32 params for mixed precision training
@@ -164,7 +161,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
             + f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
             + f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
-            + f"zo-{self._zero_local_rank}.pt"
+            + f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt"
         )
         self.params_per_rank_id_dict = []
         self._param_bcast_sync_handler = param_bcast_sync_handler
@@ -180,10 +177,23 @@ class HybridZeroOptimizer(BaseOptimizer):
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
 
+            # to find real zero mode. if zero is not used, set all param group as ParallelMode.ZERO1
+            # if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode
+            zero_mode = (
+                ParallelMode.ZERO1
+                if gpc.get_world_size(ParallelMode.ZERO1) == 1 or param_group["dp_mode"] == ParallelMode.DATA
+                else ParallelMode.EXPERT_DATA
+            )
+            self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
+            self._zero_world_size.append(gpc.get_world_size(zero_mode))
+            # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
+            self._broadcast_parallel_mode.append(zero_mode)
+            self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
+
             # assign parameters to ranks, the params in the list are sorted
-            params_per_rank, no_params_ranks = self._partition_param_list(param_group)
+            params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
             self.param_group_no_params_ranks.append(no_params_ranks)
-            self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
+            self.param_group_has_params.append(self._zero_local_rank[group_id] not in no_params_ranks)
 
             # store the mapping between param to rank, each param should belong to only one rank
             # we can skip the moe param and do not keep them in _param_store to save memory
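The comments in the added block state the key decision of this patch: each parameter group now selects its own ZeRO mode, so expert (MoE) groups shard optimizer state over the expert-data group while dense groups keep the regular ZERO1 group, and everything falls back to ZERO1 when ZeRO sharding is effectively off. A minimal standalone sketch of that decision, using a stand-in enum instead of InternLM's `ParallelMode` and an explicit world-size argument instead of `gpc`:

```python
from enum import Enum, auto


class ParallelMode(Enum):
    # stand-in for internlm.core.context.ParallelMode
    DATA = auto()
    ZERO1 = auto()
    EXPERT_DATA = auto()


def select_zero_mode(group_dp_mode: ParallelMode, zero1_world_size: int) -> ParallelMode:
    """Mirror of the ternary above: ZERO1 unless this is an expert-dp group
    and ZeRO is actually sharding (world size > 1)."""
    if zero1_world_size == 1 or group_dp_mode == ParallelMode.DATA:
        return ParallelMode.ZERO1
    return ParallelMode.EXPERT_DATA


if __name__ == "__main__":
    # dense group with zero1 size 8 -> sharded over the ZERO1 group
    assert select_zero_mode(ParallelMode.DATA, 8) is ParallelMode.ZERO1
    # expert group with zero1 size 8 -> sharded over the smaller expert-data group
    assert select_zero_mode(ParallelMode.EXPERT_DATA, 8) is ParallelMode.EXPERT_DATA
    # zero disabled (world size 1) -> everything stays on ZERO1
    assert select_zero_mode(ParallelMode.EXPERT_DATA, 1) is ParallelMode.ZERO1
```

Once the mode is chosen, the local rank, world size, broadcast mode, and bucket store recorded in the same loop are all indexed by `group_id` throughout the rest of the diff.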
@@ -202,7 +212,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 param.data = param.data.cpu()
 
             # flatten the reordered tensors
-            for rank in range(self._zero_world_size):
+            for rank in range(self._zero_world_size[group_id]):
                 # No flat fp16 buffer is allocated if the process has no parameters.
                 if rank not in self.param_group_no_params_ranks[group_id]:
                     tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
@@ -216,7 +226,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
                 fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
+                    self._zero_local_rank[group_id], group_id
                 )
                 fp32_flat_current_rank = fp16_flat_current_rank.float()
                 device = "cpu" if self._cpu_offload else get_current_device()
@@ -265,44 +275,36 @@ class HybridZeroOptimizer(BaseOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)
 
-    def _partition_param_list(self, param_group):
+    def _partition_param_list(self, group_id, param_group):
         no_params_ranks = []
-        params_per_rank = [[] for _ in range(self._zero_world_size)]
-        numel_per_rank = [0 for _ in range(self._zero_world_size)]
-        self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
+        params_per_rank = [[] for _ in range(self._zero_world_size[group_id])]
+        numel_per_rank = [0 for _ in range(self._zero_world_size[group_id])]
+        self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size[group_id])])
         param_list = param_group["params"]
 
-        if self._is_moe_group(param_group):
-            # for moe group, we do not need to partition the params, just add current
-            # params to params_per_rank[_zero_local_rank]
-            params_per_rank[self._zero_local_rank] = list(param_list)
-            self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
-            no_params_ranks = list(range(self._zero_world_size))
-            no_params_ranks.pop(self._zero_local_rank)
-
-        else:
-            sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
-            for i, param in enumerate(sorted_params):
-                global_id = str(i)
-                for j in range(len(param.size())):
-                    global_id = "_".join([global_id, str(param.size()[j])])
-                if self._overlap_sync_param:
-                    rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
-                else:
-                    rank_to_go = numel_per_rank.index(min(numel_per_rank))
-                params_per_rank[rank_to_go].append(param)
-                self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
-                numel_per_rank[rank_to_go] += param.numel()
-
-            # check whether any rank is not assigned to parameters.
-            for rank, params in enumerate(params_per_rank):
-                if len(params) == 0:
-                    no_params_ranks.append(rank)
-
-            if gpc.is_rank_for_log():
-                logger.info(  # pylint: disable=W1203
-                    f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
-                )
+        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
+        for i, param in enumerate(sorted_params):
+            global_id = str(i)
+            for j in range(len(param.size())):
+                global_id = "_".join([global_id, str(param.size()[j])])
+            if self._overlap_sync_param:
+                assert not hasattr(gpc.config.model, "num_experts")
+                rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
+            else:
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
+            params_per_rank[rank_to_go].append(param)
+            self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
+            numel_per_rank[rank_to_go] += param.numel()
+
+        # check whether any rank is not assigned to parameters.
+        for rank, params in enumerate(params_per_rank):
+            if len(params) == 0:
+                no_params_ranks.append(rank)
+
+        if gpc.is_rank_for_log():
+            logger.info(  # pylint: disable=W1203
+                f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
+            )
 
         return params_per_rank, set(no_params_ranks)
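After this change `_partition_param_list` treats every group the same way (the special MoE branch is gone): parameters are sorted by element count, largest first, each one is assigned to whichever rank currently owns the fewest elements, and any rank left with nothing is reported so no flat buffer is created for it. A self-contained sketch of that greedy scheme, with plain integers standing in for torch parameters (`partition_by_numel` is a hypothetical helper, not part of the optimizer):

```python
from typing import List, Set, Tuple


def partition_by_numel(param_sizes: List[int], world_size: int) -> Tuple[List[List[int]], Set[int]]:
    """Greedy partition in the spirit of _partition_param_list, on plain sizes.

    Returns the per-rank lists of sizes and the set of ranks that received
    nothing (those ranks get no flat fp16/fp32 buffer).
    """
    params_per_rank: List[List[int]] = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size

    # largest parameters first, then always fill the lightest rank
    for size in sorted(param_sizes, reverse=True):
        rank_to_go = numel_per_rank.index(min(numel_per_rank))
        params_per_rank[rank_to_go].append(size)
        numel_per_rank[rank_to_go] += size

    no_params_ranks = {rank for rank, sizes in enumerate(params_per_rank) if not sizes}
    return params_per_rank, no_params_ranks


if __name__ == "__main__":
    buckets, empty = partition_by_numel([100, 60, 40, 10], world_size=2)
    print(buckets, empty)  # [[100, 10], [60, 40]] set()
```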
@@ -315,6 +317,7 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _is_gate_group(self, param_group):
         return "gate" in param_group.keys() and param_group["gate"]
 
+    # TODO check expert dp is correct when enable moe and overlap both
     def _attach_reduction_hook(self):
         # we iterate over the fp16 params
         # on each param, we register a hook to its AccumulateGrad object
@@ -348,16 +351,28 @@ class HybridZeroOptimizer(BaseOptimizer):
 
                     _define_and_attach(param, reduce_rank)
 
+    def belongs_to_current_rank(self, param) -> bool:
+        """
+        Check whether a parameter is supposed to be updated by the process of the current rank
+
+        :param tensor: A :class:`torch.Tensor` object
+        :type tensor: torch.Tensor
+
+        :return: True if the parameter should be updated by the current rank. Otherwise false.
+        :rtype: bool
+        """
+        tensor_rank = self._param_store.get_param_rank(param)
+        group_id = getattr(param, "group_id")
+        return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
+
     def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
         param_size = param.numel()
 
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        if is_moe_param(param):
-            current_bucket = self._moe_bucket_store
-        else:
-            current_bucket = self._non_moe_bucket_store
+        group_id = getattr(param, "group_id")
+        current_bucket = self._bucket_store[group_id]
 
         if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
             self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
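With the per-group bucket stores, a gradient is routed to the bucket of its param group and the bucket is flushed as soon as adding one more gradient would cross `_reduce_bucket_size`. A toy sketch of that flush-on-threshold bookkeeping, with a callback standing in for the real reduce path (`GradBucket` below is illustrative, not InternLM's `BucketStore`):

```python
from typing import Callable, Dict, List


class GradBucket:
    """Collects per-group gradient sizes and flushes them through `reduce_fn`
    whenever the element budget would be exceeded."""

    def __init__(self, bucket_size: int, reduce_fn: Callable[[int, List[int]], None]) -> None:
        self._bucket_size = bucket_size
        self._reduce_fn = reduce_fn
        self._elements: Dict[int, int] = {}
        self._grads: Dict[int, List[int]] = {}

    def add_grad(self, group_id: int, grad_numel: int) -> None:
        # flush first if adding this grad would overflow the bucket,
        # mirroring _store_and_try_reduce_grads_by_bucket
        if self._elements.get(group_id, 0) + grad_numel > self._bucket_size:
            self.flush(group_id)
        self._grads.setdefault(group_id, []).append(grad_numel)
        self._elements[group_id] = self._elements.get(group_id, 0) + grad_numel

    def flush(self, group_id: int) -> None:
        # analogue of _reduce_grads_stored_in_bucket followed by a reset
        if self._grads.get(group_id):
            self._reduce_fn(group_id, self._grads[group_id])
        self._grads[group_id] = []
        self._elements[group_id] = 0


if __name__ == "__main__":
    bucket = GradBucket(bucket_size=512, reduce_fn=lambda gid, g: print(f"reduce group {gid}: {g}"))
    for numel in (200, 200, 200, 50):  # the third add triggers a flush
        bucket.add_grad(group_id=0, grad_numel=numel)
    bucket.flush(group_id=0)           # analogue of the final last_bucket=True pass
```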
@@ -384,6 +399,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             reduce_rank=reduce_rank,
             grads=current_bucket.get_grad(reduce_rank=reduce_rank),
             bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
+            group_id=current_bucket.get_param_group_id(),
             dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
         )
 
@@ -404,14 +420,14 @@ class HybridZeroOptimizer(BaseOptimizer):
             # update the flag
             self._param_store.set_param_reduction_state(param, True)
 
-            if self._param_store.belongs_to_current_rank(param):
+            if self.belongs_to_current_rank(param):
                 self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
             else:
                 self._param_store.add_previous_reduced_param(param)
 
         current_bucket.reset_by_rank(reduce_rank)
 
-    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
+    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode):
         grad_buckets_by_dtype = split_half_float_double(grads)
         next_bucket_list = []
         # add parameters into bucket for reduction
@@ -420,7 +436,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
             if not param_bucket.is_empty():
-                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
+                self._reduce_and_copy(
+                    bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode
+                )
                 next_bucket_list.append(param_bucket)
 
         # wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
@@ -435,7 +453,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # after the completion of bucket list reduction, add new buckets into _bucket_in_progress
         self._bucket_in_progress = next_bucket_list.copy()
 
-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode):
         # flatten the tensors and do allreduce
         bucket.flatten()
         bucket.commu_handle = reduce_tensor(
@@ -446,7 +464,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         )
 
         # update the reduced tensor
-        if reduce_rank is None or reduce_rank == self._zero_local_rank:
+        if reduce_rank is None or reduce_rank == self._zero_local_rank[group_id]:
             bucket.set_unflatten_and_copy_flag(flag=True)
 
     def _has_inf_or_nan(self, tensor):
@@ -475,8 +493,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         avg_gradients = self._grad_store._averaged_gradients
         for group_id in range(self.num_param_groups):
             # the following operations are performed only on the rank to which parameters are assigned.
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
+            if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
+                param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank[group_id], group_id)
 
                 if group_id not in avg_gradients:
                     avg_gradients[group_id] = []
@@ -539,37 +557,11 @@ class HybridZeroOptimizer(BaseOptimizer):
                 parameters=params,
                 last_stage=last_stage,
                 previous_norm=previous_norm,
+                zero_mode=self._broadcast_parallel_mode[group_id],
             )
 
         return norm
 
-    def _compute_norm_with_moe_group(self, group_id):
-        params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
-        # we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
-        grads = [p.grad for p in params]
-
-        if len(params) == 0:
-            grads = [self.padding_grad]
-            params = [self.padding_tensor]
-
-        norm = 0
-        if self._clip_grad_norm > 0:
-            norm = compute_norm(
-                gradients=grads,
-                parameters=params,
-                last_stage=True,
-                is_moe_group=True,
-            )
-
-        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
-        # model and zero have been reduced!!!
-        pg = gpc.get_group(ParallelMode.DATA)
-        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-        dist.all_reduce(scaled_norm_tensor, group=pg)
-        all_groups_norm = scaled_norm_tensor.item()
-        return all_groups_norm
-
     @llm_timeout(func_name="optim_step")
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -592,16 +584,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                         self._store_and_try_reduce_grads_by_bucket(param)
 
         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
-        self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
+        for group_id in range(self.num_param_groups):
+            self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
 
         # compute norm for gradients in the before bucket
         groups_norms = []
         for group_id in range(self.num_param_groups):
-            if self._is_moe_group(self.optim.param_groups[group_id]):
-                groups_norms.append(None)
-            else:
-                groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
+            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
 
         # clear reduced grads
         # grads in the last bucket is reduced
@@ -617,15 +606,22 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
             group_name = f"{group_id}_{group_name}"
+            total_norms[group_name] = self._compute_norm_with_stage(
+                group_id=group_id,
+                last_bucket=True,
+                last_stage=True,
+                previous_norm=groups_norms[group_id],
+            )
+
+            # Need to allreduce(avg) the norms across different ranks because moe params will not be synced
+            # during allreduce
             if self._is_moe_group(self.optim.param_groups[group_id]):
-                total_norms[group_name] = self._compute_norm_with_moe_group(group_id=group_id)
-            else:
-                total_norms[group_name] = self._compute_norm_with_stage(
-                    group_id=group_id,
-                    last_bucket=True,
-                    last_stage=True,
-                    previous_norm=groups_norms[group_id],
-                )
+                # model and zero have been reduced!!!
+                pg = gpc.get_group(ParallelMode.EXPERT)
+                scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
+                scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+                dist.all_reduce(scaled_norm_tensor, group=pg)
+                total_norms[group_name] = scaled_norm_tensor.item()
 
         timer("sync_grad").start()
         self._sync_grad()
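For an MoE group the norm computed above only covers the experts local to this rank, so the new code averages it over the expert group by pre-scaling with `1 / world_size` and summing via `all_reduce`. A single-process sketch of that averaging pattern on the default process group (the real code uses `gpc.get_group(ParallelMode.EXPERT)`; the gloo backend and port below are arbitrary choices for the demo):

```python
import os

import torch
import torch.distributed as dist


def allreduce_avg_norm(local_norm: float, group=None) -> float:
    """Average a scalar norm across a process group by pre-scaling and summing."""
    world_size = dist.get_world_size(group=group)
    scaled = torch.tensor(local_norm / float(world_size), dtype=torch.float)
    dist.all_reduce(scaled, group=group)  # sum of pre-scaled values == average
    return scaled.item()


if __name__ == "__main__":
    # single-process demo; in training this group would be the expert-parallel group
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(allreduce_avg_norm(4.0))  # 4.0 with a single rank
    dist.destroy_process_group()
```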
@@ -747,7 +743,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(len(self._fp16_param_groups)):
             if self.param_group_has_params[group_id]:
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
+                    rank=self._zero_local_rank[group_id], group_id=group_id
                 )
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
@@ -767,27 +763,26 @@ class HybridZeroOptimizer(BaseOptimizer):
     def broadcast_params(self):
         handles = []
 
-        for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
-            if self._is_moe_group(self.optim.param_groups[group_id]):
-                continue
-            # The following operations are performed only on the rank to which parameters are assigned.
-            if rank in self.param_group_no_params_ranks[group_id]:
-                continue
-            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-            # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-            # assert grank == rank, f"{grank} == {rank}"
-            g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
-            handle = dist.broadcast(
-                fp16_param,
-                src=g_rank,
-                group=gpc.get_group(ParallelMode.ZERO1),
-                async_op=True,
-            )
+        for group_id in range(self.num_param_groups):
+            for rank in range(self._zero_world_size[group_id]):
+                # The following operations are performed only on the rank to which parameters are assigned.
+                if rank in self.param_group_no_params_ranks[group_id]:
+                    continue
+                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+                # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
+                # assert grank == rank, f"{grank} == {rank}"
+                g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
+                handle = dist.broadcast(
+                    fp16_param,
+                    src=g_rank,
+                    group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                    async_op=True,
+                )
 
-            if self._overlap_sync_param:
-                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
-            else:
-                handles.append(handle)
+                if self._overlap_sync_param:
+                    self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+                else:
+                    handles.append(handle)
 
         for handle in handles:
             handle.wait()
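`broadcast_params` now walks groups and ranks explicitly, translates each local ZeRO rank to a global rank, issues the broadcasts with `async_op=True`, and only waits on the collected handles at the end. A minimal sketch of that issue-then-wait pattern, again on the default process group with a single rank (the shard list and port are made up for the demo):

```python
import os

import torch
import torch.distributed as dist


def broadcast_shards(shards, group=None) -> None:
    """Issue one async broadcast per (owner rank, shard) and wait at the end,
    the same async_op=True / handle.wait() pattern used in broadcast_params."""
    handles = []
    for owner_rank, shard in enumerate(shards):
        # each owner pushes its flat shard to every other rank in the group
        handles.append(dist.broadcast(shard, src=owner_rank, group=group, async_op=True))
    for handle in handles:
        handle.wait()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    shards = [torch.arange(4, dtype=torch.float)]  # one flat shard per rank; world_size == 1 here
    broadcast_shards(shards)
    dist.destroy_process_group()
```

Overlapping the broadcasts and deferring the waits keeps the parameter synchronization off the critical path when `overlap_sync_param` hands the handles to the bcast sync handler instead.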
@@ -803,7 +798,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # check for overflow
         for group_id in range(len(self._fp16_param_groups)):
             # The following operations are performed only on the rank to which parameters are assigned.
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+            if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
                 for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                     if avg_grad is not None and has_inf_or_nan(avg_grad):
                         self._found_overflow.fill_(1.0)
@@ -844,7 +839,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         flat_fp32_weights = {}
         for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+            if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
                 assert param.grad is None
                 flat_fp32_weights[group_id] = param
         states["flat_fp32_weights"] = flat_fp32_weights
@@ -864,7 +859,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         flat_fp32_weights = states["flat_fp32_weights"]
         assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
         for group_id, param in flat_fp32_weights.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+            if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
                 self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 assert (
                     self_param.shape == param.shape
@@ -873,9 +868,9 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # Load the fp16 model weights.
         for group_id in range(len(self._fp16_param_groups)):
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+            if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
+                    rank=self._zero_local_rank[group_id], group_id=group_id
                 )
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
@@ -892,7 +887,7 @@ def reload_zero_fp32_buff(optimizer):
             if optimizer.param_group_has_params[group_id]:
                 # flatten fp16 params have already been updated by 'load_model_checkpoint'
                 fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
-                    optimizer._zero_local_rank, group_id
+                    optimizer._zero_local_rank[group_id], group_id
                 )
                 # param_group["params"] is fp32 flatten optimizer states of this zero rank.
                 param_group["params"][0].data.copy_(fp16_flat_current_rank.float())