mirror of https://github.com/InternLM/InternLM
feat(moe):support zero for expert local dp (#404)
* support zero for expert local dp * fix above codes: *treat optim.zero_world_size and optim.zero_local_rank as list in model_checkpoint.py and test_model_checkpoint.py *add overlap and zero check for moe in args_sanity_check(.)pull/407/head
parent
916647c0a1
commit
582ee000bd
|
@ -319,6 +319,13 @@ def args_sanity_check():
|
||||||
if "moe_loss_coeff" not in gpc.config.loss:
|
if "moe_loss_coeff" not in gpc.config.loss:
|
||||||
gpc.config.loss._add_item("moe_loss_coeff", 1.0)
|
gpc.config.loss._add_item("moe_loss_coeff", 1.0)
|
||||||
|
|
||||||
|
# moe not support overlap and zero1.5 for now
|
||||||
|
if hasattr(gpc.config.model, "num_experts"):
|
||||||
|
assert (
|
||||||
|
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
|
||||||
|
), "not support overlap and moe at the same time"
|
||||||
|
assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
|
||||||
|
|
||||||
|
|
||||||
def launch(
|
def launch(
|
||||||
config: Union[str, Path, Config, Dict],
|
config: Union[str, Path, Config, Dict],
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import product
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -11,7 +10,6 @@ from torch.optim import Optimizer
|
||||||
|
|
||||||
from internlm.core.context import Config, ParallelMode
|
from internlm.core.context import Config, ParallelMode
|
||||||
from internlm.core.context import global_context as gpc
|
from internlm.core.context import global_context as gpc
|
||||||
from internlm.model.utils import is_moe_param
|
|
||||||
from internlm.monitor import send_alert_message
|
from internlm.monitor import send_alert_message
|
||||||
from internlm.solver.optimizer.store import (
|
from internlm.solver.optimizer.store import (
|
||||||
BucketStore,
|
BucketStore,
|
||||||
|
@ -116,16 +114,15 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
super().__init__(optim=optimizer)
|
super().__init__(optim=optimizer)
|
||||||
|
|
||||||
self._cpu_offload = cpu_offload
|
self._cpu_offload = cpu_offload
|
||||||
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
self._zero_local_rank = []
|
||||||
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
|
self._zero_world_size = []
|
||||||
self._broadcast_parallel_mode = ParallelMode.ZERO1
|
self._broadcast_parallel_mode = []
|
||||||
|
|
||||||
# ParameterStore will manage the tensor buffers used for zero
|
# ParameterStore will manage the tensor buffers used for zero
|
||||||
# it will not manage the tensors used by mixed precision training
|
# it will not manage the tensors used by mixed precision training
|
||||||
self._param_store = ParameterStore(ParallelMode.ZERO1)
|
self._param_store = ParameterStore(ParallelMode.ZERO1)
|
||||||
self._grad_store = GradientStore(ParallelMode.DATA)
|
self._grad_store = GradientStore(ParallelMode.DATA)
|
||||||
self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
|
self._bucket_store = []
|
||||||
self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
|
|
||||||
self._bucket_in_progress = []
|
self._bucket_in_progress = []
|
||||||
|
|
||||||
# fp16 and fp32 params for mixed precision training
|
# fp16 and fp32 params for mixed precision training
|
||||||
|
@ -163,7 +160,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
|
f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
|
||||||
+ f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
|
+ f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
|
||||||
+ f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
+ f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||||
+ f"zo-{self._zero_local_rank}.pt"
|
+ f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt"
|
||||||
)
|
)
|
||||||
self.params_per_rank_id_dict = []
|
self.params_per_rank_id_dict = []
|
||||||
self._param_bcast_sync_handler = param_bcast_sync_handler
|
self._param_bcast_sync_handler = param_bcast_sync_handler
|
||||||
|
@ -182,10 +179,23 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# add the fp16 params to fp16_param_groups for bookkeeping
|
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||||
self._fp16_param_groups[group_id] = group_params
|
self._fp16_param_groups[group_id] = group_params
|
||||||
|
|
||||||
|
# to find real zero mode. if zero is not used, set all param group as ParallelMode.ZERO1
|
||||||
|
# if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode
|
||||||
|
zero_mode = (
|
||||||
|
ParallelMode.ZERO1
|
||||||
|
if param_group["dp_mode"] == gpc.get_world_size(ParallelMode.ZERO1) == 1 or ParallelMode.DATA
|
||||||
|
else ParallelMode.EXPERT_DATA
|
||||||
|
)
|
||||||
|
self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
|
||||||
|
self._zero_world_size.append(gpc.get_world_size(zero_mode))
|
||||||
|
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
|
||||||
|
self._broadcast_parallel_mode.append(zero_mode)
|
||||||
|
self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
|
||||||
|
|
||||||
# assign parameters to ranks the params in the list are sorted
|
# assign parameters to ranks the params in the list are sorted
|
||||||
params_per_rank, no_params_ranks = self._partition_param_list(param_group)
|
params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
|
||||||
self.param_group_no_params_ranks.append(no_params_ranks)
|
self.param_group_no_params_ranks.append(no_params_ranks)
|
||||||
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
|
self.param_group_has_params.append(self._zero_local_rank[group_id] not in no_params_ranks)
|
||||||
|
|
||||||
# store the mapping between param to rank each param should belong to only one rank.
|
# store the mapping between param to rank each param should belong to only one rank.
|
||||||
# we can skip the moe param and do not keep them in _param_store to save memory
|
# we can skip the moe param and do not keep them in _param_store to save memory
|
||||||
|
@ -204,7 +214,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
param.data = param.data.cpu()
|
param.data = param.data.cpu()
|
||||||
|
|
||||||
# flatten the reordered tensors
|
# flatten the reordered tensors
|
||||||
for rank in range(self._zero_world_size):
|
for rank in range(self._zero_world_size[group_id]):
|
||||||
# No flat fp16 buffer is allocated if the process has no parameters.
|
# No flat fp16 buffer is allocated if the process has no parameters.
|
||||||
if rank not in self.param_group_no_params_ranks[group_id]:
|
if rank not in self.param_group_no_params_ranks[group_id]:
|
||||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||||
|
@ -218,7 +228,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# No flat fp32 buffer is allocated if the process has no parameters.
|
# No flat fp32 buffer is allocated if the process has no parameters.
|
||||||
if self.param_group_has_params[group_id]:
|
if self.param_group_has_params[group_id]:
|
||||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
self._zero_local_rank, group_id
|
self._zero_local_rank[group_id], group_id
|
||||||
)
|
)
|
||||||
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
||||||
device = "cpu" if self._cpu_offload else get_current_device()
|
device = "cpu" if self._cpu_offload else get_current_device()
|
||||||
|
@ -263,44 +273,35 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
def num_param_groups(self):
|
def num_param_groups(self):
|
||||||
return len(self._fp16_param_groups)
|
return len(self._fp16_param_groups)
|
||||||
|
|
||||||
def _partition_param_list(self, param_group):
|
def _partition_param_list(self, group_id, param_group):
|
||||||
no_params_ranks = []
|
no_params_ranks = []
|
||||||
params_per_rank = [[] for _ in range(self._zero_world_size)]
|
params_per_rank = [[] for _ in range(self._zero_world_size[group_id])]
|
||||||
numel_per_rank = [0 for _ in range(self._zero_world_size)]
|
numel_per_rank = [0 for _ in range(self._zero_world_size[group_id])]
|
||||||
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
|
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size[group_id])])
|
||||||
param_list = param_group["params"]
|
param_list = param_group["params"]
|
||||||
|
|
||||||
if self._is_moe_group(param_group):
|
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||||
# for moe group, we do not need to partition the params, just add current
|
for i, param in enumerate(sorted_params):
|
||||||
# params to params_per_rank[_zero_local_rank]
|
global_id = str(i)
|
||||||
params_per_rank[self._zero_local_rank] = list(param_list)
|
for j in range(len(param.size())):
|
||||||
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
|
global_id = "_".join([global_id, str(param.size()[j])])
|
||||||
no_params_ranks = list(range(self._zero_world_size))
|
if self._overlap_sync_param:
|
||||||
no_params_ranks.pop(self._zero_local_rank)
|
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
||||||
|
else:
|
||||||
|
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||||
|
params_per_rank[rank_to_go].append(param)
|
||||||
|
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
||||||
|
numel_per_rank[rank_to_go] += param.numel()
|
||||||
|
|
||||||
else:
|
# check whether any rank is not assigned to parameters.
|
||||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
for rank, params in enumerate(params_per_rank):
|
||||||
for i, param in enumerate(sorted_params):
|
if len(params) == 0:
|
||||||
global_id = str(i)
|
no_params_ranks.append(rank)
|
||||||
for j in range(len(param.size())):
|
|
||||||
global_id = "_".join([global_id, str(param.size()[j])])
|
|
||||||
if self._overlap_sync_param:
|
|
||||||
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
|
||||||
else:
|
|
||||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
|
||||||
params_per_rank[rank_to_go].append(param)
|
|
||||||
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
|
||||||
numel_per_rank[rank_to_go] += param.numel()
|
|
||||||
|
|
||||||
# check whether any rank is not assigned to parameters.
|
if gpc.is_rank_for_log():
|
||||||
for rank, params in enumerate(params_per_rank):
|
logger.info( # pylint: disable=W1203
|
||||||
if len(params) == 0:
|
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
|
||||||
no_params_ranks.append(rank)
|
)
|
||||||
|
|
||||||
if gpc.is_rank_for_log():
|
|
||||||
logger.info( # pylint: disable=W1203
|
|
||||||
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return params_per_rank, set(no_params_ranks)
|
return params_per_rank, set(no_params_ranks)
|
||||||
|
|
||||||
|
@ -313,6 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
def _is_gate_group(self, param_group):
|
def _is_gate_group(self, param_group):
|
||||||
return "gate" in param_group.keys() and param_group["gate"]
|
return "gate" in param_group.keys() and param_group["gate"]
|
||||||
|
|
||||||
|
# TODO check expert dp is correct when enable moe and overlap both
|
||||||
def _attach_reduction_hook(self):
|
def _attach_reduction_hook(self):
|
||||||
# we iterate over the fp16 params
|
# we iterate over the fp16 params
|
||||||
# on each param, we register a hook to its AccumulateGrad object
|
# on each param, we register a hook to its AccumulateGrad object
|
||||||
|
@ -346,16 +348,28 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
|
|
||||||
_define_and_attach(param, reduce_rank)
|
_define_and_attach(param, reduce_rank)
|
||||||
|
|
||||||
|
def belongs_to_current_rank(self, param) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||||
|
|
||||||
|
:param tensor: A :class:`torch.Tensor` object
|
||||||
|
:type tensor: torch.Tensor
|
||||||
|
|
||||||
|
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
tensor_rank = self._param_store.get_param_rank(param)
|
||||||
|
group_id = getattr(param, "group_id")
|
||||||
|
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
|
||||||
|
|
||||||
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
||||||
param_size = param.numel()
|
param_size = param.numel()
|
||||||
|
|
||||||
# check if the bucket is full
|
# check if the bucket is full
|
||||||
# if full, will reduce the grads already in the bucket
|
# if full, will reduce the grads already in the bucket
|
||||||
# after reduction, the bucket will be empty
|
# after reduction, the bucket will be empty
|
||||||
if is_moe_param(param):
|
group_id = getattr(param, "group_id")
|
||||||
current_bucket = self._moe_bucket_store
|
current_bucket = self._bucket_store[group_id]
|
||||||
else:
|
|
||||||
current_bucket = self._non_moe_bucket_store
|
|
||||||
|
|
||||||
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||||
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
|
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
|
||||||
|
@ -382,6 +396,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
reduce_rank=reduce_rank,
|
reduce_rank=reduce_rank,
|
||||||
grads=current_bucket.get_grad(reduce_rank=reduce_rank),
|
grads=current_bucket.get_grad(reduce_rank=reduce_rank),
|
||||||
bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
|
bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
|
||||||
|
group_id=current_bucket.get_param_group_id(),
|
||||||
dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
|
dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -402,14 +417,14 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# update the flag
|
# update the flag
|
||||||
self._param_store.set_param_reduction_state(param, True)
|
self._param_store.set_param_reduction_state(param, True)
|
||||||
|
|
||||||
if self._param_store.belongs_to_current_rank(param):
|
if self.belongs_to_current_rank(param):
|
||||||
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
|
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
|
||||||
else:
|
else:
|
||||||
self._param_store.add_previous_reduced_param(param)
|
self._param_store.add_previous_reduced_param(param)
|
||||||
|
|
||||||
current_bucket.reset_by_rank(reduce_rank)
|
current_bucket.reset_by_rank(reduce_rank)
|
||||||
|
|
||||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
|
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode):
|
||||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||||
next_bucket_list = []
|
next_bucket_list = []
|
||||||
# add parameters into bucket for reduction
|
# add parameters into bucket for reduction
|
||||||
|
@ -418,7 +433,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
for tensor in tensor_list:
|
for tensor in tensor_list:
|
||||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||||
if not param_bucket.is_empty():
|
if not param_bucket.is_empty():
|
||||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
|
self._reduce_and_copy(
|
||||||
|
bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode
|
||||||
|
)
|
||||||
next_bucket_list.append(param_bucket)
|
next_bucket_list.append(param_bucket)
|
||||||
|
|
||||||
# wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
|
# wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
|
||||||
|
@ -433,7 +450,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
|
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
|
||||||
self._bucket_in_progress = next_bucket_list.copy()
|
self._bucket_in_progress = next_bucket_list.copy()
|
||||||
|
|
||||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
|
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode):
|
||||||
# flatten the tensors and do allreduce
|
# flatten the tensors and do allreduce
|
||||||
bucket.flatten()
|
bucket.flatten()
|
||||||
bucket.commu_handle = reduce_tensor(
|
bucket.commu_handle = reduce_tensor(
|
||||||
|
@ -444,7 +461,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
)
|
)
|
||||||
|
|
||||||
# update the reduced tensor
|
# update the reduced tensor
|
||||||
if reduce_rank is None or reduce_rank == self._zero_local_rank:
|
if reduce_rank is None or reduce_rank == self._zero_local_rank[group_id]:
|
||||||
bucket.set_unflatten_and_copy_flag(flag=True)
|
bucket.set_unflatten_and_copy_flag(flag=True)
|
||||||
|
|
||||||
def _has_inf_or_nan(self, tensor):
|
def _has_inf_or_nan(self, tensor):
|
||||||
|
@ -473,8 +490,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
avg_gradients = self._grad_store._averaged_gradients
|
avg_gradients = self._grad_store._averaged_gradients
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
# the following operations are performed only on the rank to which parameters are assigned.
|
# the following operations are performed only on the rank to which parameters are assigned.
|
||||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||||
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
|
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank[group_id], group_id)
|
||||||
|
|
||||||
if group_id not in avg_gradients:
|
if group_id not in avg_gradients:
|
||||||
avg_gradients[group_id] = []
|
avg_gradients[group_id] = []
|
||||||
|
@ -538,37 +555,11 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
parameters=params,
|
parameters=params,
|
||||||
last_stage=last_stage,
|
last_stage=last_stage,
|
||||||
previous_norm=previous_norm,
|
previous_norm=previous_norm,
|
||||||
|
zero_mode=self._broadcast_parallel_mode[group_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
return norm
|
return norm
|
||||||
|
|
||||||
def _compute_norm_with_moe_group(self, group_id):
|
|
||||||
params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
|
|
||||||
# we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
|
|
||||||
grads = [p.grad for p in params]
|
|
||||||
|
|
||||||
if len(params) == 0:
|
|
||||||
grads = [self.padding_grad]
|
|
||||||
params = [self.padding_tensor]
|
|
||||||
|
|
||||||
norm = 0
|
|
||||||
if self._clip_grad_norm > 0:
|
|
||||||
norm = compute_norm(
|
|
||||||
gradients=grads,
|
|
||||||
parameters=params,
|
|
||||||
last_stage=True,
|
|
||||||
is_moe_group=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
|
||||||
# model and zero have been reduced!!!
|
|
||||||
pg = gpc.get_group(ParallelMode.DATA)
|
|
||||||
scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
|
||||||
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
|
||||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
|
||||||
all_groups_norm = scaled_norm_tensor.item()
|
|
||||||
return all_groups_norm
|
|
||||||
|
|
||||||
@llm_timeout(func_name="optim_step")
|
@llm_timeout(func_name="optim_step")
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
|
@ -591,16 +582,13 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
self._store_and_try_reduce_grads_by_bucket(param)
|
self._store_and_try_reduce_grads_by_bucket(param)
|
||||||
|
|
||||||
# we need to reduce the gradients left in the communication bucket
|
# we need to reduce the gradients left in the communication bucket
|
||||||
self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
|
for group_id in range(self.num_param_groups):
|
||||||
self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
|
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
|
||||||
|
|
||||||
# compute norm for gradients in the before bucket
|
# compute norm for gradients in the before bucket
|
||||||
groups_norms = []
|
groups_norms = []
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||||
groups_norms.append(None)
|
|
||||||
else:
|
|
||||||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
|
||||||
|
|
||||||
# clear reduced grads
|
# clear reduced grads
|
||||||
# grads in the last bucket is reduced
|
# grads in the last bucket is reduced
|
||||||
|
@ -616,15 +604,22 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
||||||
group_name = f"{group_id}_{group_name}"
|
group_name = f"{group_id}_{group_name}"
|
||||||
|
total_norms[group_name] = self._compute_norm_with_stage(
|
||||||
|
group_id=group_id,
|
||||||
|
last_bucket=True,
|
||||||
|
last_stage=True,
|
||||||
|
previous_norm=groups_norms[group_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced
|
||||||
|
# during allreduce
|
||||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
if self._is_moe_group(self.optim.param_groups[group_id]):
|
||||||
total_norms[group_name] = self._compute_norm_with_moe_group(group_id=group_id)
|
# model and zero have been reduced!!!
|
||||||
else:
|
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||||
total_norms[group_name] = self._compute_norm_with_stage(
|
scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
|
||||||
group_id=group_id,
|
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||||
last_bucket=True,
|
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||||
last_stage=True,
|
total_norms[group_name] = scaled_norm_tensor.item()
|
||||||
previous_norm=groups_norms[group_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
timer("sync_grad").start()
|
timer("sync_grad").start()
|
||||||
self._sync_grad()
|
self._sync_grad()
|
||||||
|
@ -746,7 +741,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
for group_id in range(len(self._fp16_param_groups)):
|
for group_id in range(len(self._fp16_param_groups)):
|
||||||
if self.param_group_has_params[group_id]:
|
if self.param_group_has_params[group_id]:
|
||||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
rank=self._zero_local_rank, group_id=group_id
|
rank=self._zero_local_rank[group_id], group_id=group_id
|
||||||
)
|
)
|
||||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||||
fp16_param.data.copy_(fp32_param)
|
fp16_param.data.copy_(fp32_param)
|
||||||
|
@ -766,27 +761,26 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
def broadcast_params(self):
|
def broadcast_params(self):
|
||||||
handles = []
|
handles = []
|
||||||
|
|
||||||
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
|
for group_id in range(self.num_param_groups):
|
||||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
for rank in range(self._zero_world_size[group_id]):
|
||||||
continue
|
# The following operations are performed only on the rank to which parameters are assigned.
|
||||||
# The following operations are performed only on the rank to which parameters are assigned.
|
if rank in self.param_group_no_params_ranks[group_id]:
|
||||||
if rank in self.param_group_no_params_ranks[group_id]:
|
continue
|
||||||
continue
|
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
# assert grank == rank, f"{grank} == {rank}"
|
||||||
# assert grank == rank, f"{grank} == {rank}"
|
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
|
||||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
|
handle = dist.broadcast(
|
||||||
handle = dist.broadcast(
|
fp16_param,
|
||||||
fp16_param,
|
src=g_rank,
|
||||||
src=g_rank,
|
group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
|
||||||
group=gpc.get_group(ParallelMode.ZERO1),
|
async_op=True,
|
||||||
async_op=True,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if self._overlap_sync_param:
|
if self._overlap_sync_param:
|
||||||
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
|
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
|
||||||
else:
|
else:
|
||||||
handles.append(handle)
|
handles.append(handle)
|
||||||
|
|
||||||
for handle in handles:
|
for handle in handles:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
|
@ -802,7 +796,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# check for overflow
|
# check for overflow
|
||||||
for group_id in range(len(self._fp16_param_groups)):
|
for group_id in range(len(self._fp16_param_groups)):
|
||||||
# The following operations are performed only on the rank to which parameters are assigned.
|
# The following operations are performed only on the rank to which parameters are assigned.
|
||||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||||
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
||||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||||
self._found_overflow.fill_(1.0)
|
self._found_overflow.fill_(1.0)
|
||||||
|
@ -843,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
|
|
||||||
flat_fp32_weights = {}
|
flat_fp32_weights = {}
|
||||||
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
|
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
|
||||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||||
assert param.grad is None
|
assert param.grad is None
|
||||||
flat_fp32_weights[group_id] = param
|
flat_fp32_weights[group_id] = param
|
||||||
states["flat_fp32_weights"] = flat_fp32_weights
|
states["flat_fp32_weights"] = flat_fp32_weights
|
||||||
|
@ -863,7 +857,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
flat_fp32_weights = states["flat_fp32_weights"]
|
flat_fp32_weights = states["flat_fp32_weights"]
|
||||||
assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
|
assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
|
||||||
for group_id, param in flat_fp32_weights.items():
|
for group_id, param in flat_fp32_weights.items():
|
||||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||||
self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||||
assert (
|
assert (
|
||||||
self_param.shape == param.shape
|
self_param.shape == param.shape
|
||||||
|
@ -872,9 +866,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
|
|
||||||
# Load the fp16 model weights.
|
# Load the fp16 model weights.
|
||||||
for group_id in range(len(self._fp16_param_groups)):
|
for group_id in range(len(self._fp16_param_groups)):
|
||||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
rank=self._zero_local_rank, group_id=group_id
|
rank=self._zero_local_rank[group_id], group_id=group_id
|
||||||
)
|
)
|
||||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||||
fp16_param.data.copy_(fp32_param)
|
fp16_param.data.copy_(fp32_param)
|
||||||
|
@ -891,7 +885,7 @@ def reload_zero_fp32_buff(optimizer):
|
||||||
if optimizer.param_group_has_params[group_id]:
|
if optimizer.param_group_has_params[group_id]:
|
||||||
# flatten fp16 params have already been updated by 'load_model_checkpoint'
|
# flatten fp16 params have already been updated by 'load_model_checkpoint'
|
||||||
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
optimizer._zero_local_rank, group_id
|
optimizer._zero_local_rank[group_id], group_id
|
||||||
)
|
)
|
||||||
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
|
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
|
||||||
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
|
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
|
||||||
|
|
|
@ -33,18 +33,22 @@ class BucketStore(BaseStore):
|
||||||
Bucket Store
|
Bucket Store
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dp_parallel_mode):
|
def __init__(self, group_id, dp_parallel_mode):
|
||||||
super().__init__(dp_parallel_mode)
|
super().__init__(dp_parallel_mode)
|
||||||
self._grads = dict()
|
self._grads = dict()
|
||||||
self._params = dict()
|
self._params = dict()
|
||||||
self._num_elements_in_bucket = dict()
|
self._num_elements_in_bucket = dict()
|
||||||
self._dp_parallel_mode = dp_parallel_mode
|
self._dp_parallel_mode = dp_parallel_mode
|
||||||
|
self._group_id = group_id
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||||
return self._num_elements_in_bucket[reduce_rank]
|
return self._num_elements_in_bucket[reduce_rank]
|
||||||
|
|
||||||
|
def get_param_group_id(self):
|
||||||
|
return self._group_id
|
||||||
|
|
||||||
def get_dp_parallel_mode(self):
|
def get_dp_parallel_mode(self):
|
||||||
return self._dp_parallel_mode
|
return self._dp_parallel_mode
|
||||||
|
|
||||||
|
@ -182,20 +186,6 @@ class ParameterStore(BaseStore):
|
||||||
"""
|
"""
|
||||||
return self._fp16_param_to_rank[tensor]
|
return self._fp16_param_to_rank[tensor]
|
||||||
|
|
||||||
def belongs_to_current_rank(self, tensor) -> bool:
|
|
||||||
"""
|
|
||||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
|
||||||
|
|
||||||
:param tensor: A :class:`torch.Tensor` object
|
|
||||||
:type tensor: torch.Tensor
|
|
||||||
|
|
||||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
|
||||||
|
|
||||||
tensor_rank = self._fp16_param_to_rank[tensor]
|
|
||||||
return tensor_rank == self._local_rank
|
|
||||||
|
|
||||||
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
|
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
|
||||||
if rank not in self._rank_groupid_to_fp16_param_list:
|
if rank not in self._rank_groupid_to_fp16_param_list:
|
||||||
self._rank_groupid_to_fp16_param_list[rank] = dict()
|
self._rank_groupid_to_fp16_param_list[rank] = dict()
|
||||||
|
|
|
@ -209,7 +209,9 @@ def calc_lp(grads, norm_type):
|
||||||
return norm
|
return norm
|
||||||
|
|
||||||
|
|
||||||
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
|
def compute_norm(
|
||||||
|
gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1
|
||||||
|
):
|
||||||
"""Get the norm
|
"""Get the norm
|
||||||
Arguments:
|
Arguments:
|
||||||
gradients (Iterable[Tensor]): The gradient value.
|
gradients (Iterable[Tensor]): The gradient value.
|
||||||
|
@ -302,8 +304,7 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
|
||||||
|
|
||||||
# This is because we use zero1, so we need to use this reduction.
|
# This is because we use zero1, so we need to use this reduction.
|
||||||
# TODO: Check zero group to be a subset of dp group.
|
# TODO: Check zero group to be a subset of dp group.
|
||||||
if not is_moe_group:
|
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
|
||||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
|
||||||
|
|
||||||
if torch.is_tensor(total_norm):
|
if torch.is_tensor(total_norm):
|
||||||
total_norm = total_norm.item()
|
total_norm = total_norm.item()
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from internlm.core.context.parallel_context import ParallelMode
|
||||||
from internlm.core.context.parallel_context import global_context as gpc
|
from internlm.core.context.parallel_context import global_context as gpc
|
||||||
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
|
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
|
||||||
|
|
||||||
|
@ -37,13 +38,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
||||||
|
|
||||||
# create new groups for fp32, norm, moe gate and moe expert
|
# create new groups for fp32, norm, moe gate and moe expert
|
||||||
new_groups = {}
|
new_groups = {}
|
||||||
new_groups["fp32"] = {"name": "fp32", "params": []}
|
new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA}
|
||||||
if gpc.config.model.get("num_experts", 0) > 1:
|
if gpc.config.model.get("num_experts", 0) > 1:
|
||||||
# norm and gate are special group to force sync (when enable MoE).
|
# norm and gate are special group to force sync (when enable MoE).
|
||||||
for key in ["gate", "norm"]:
|
for key in ["gate", "norm"]:
|
||||||
new_groups[key] = {"name": key, key: True, "params": []}
|
new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA}
|
||||||
for key in gpc.expert_parallel_group_names:
|
for key in gpc.expert_parallel_group_names:
|
||||||
new_groups[key] = {"name": key, "moe": True, "params": []}
|
new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA}
|
||||||
|
|
||||||
for pgroup in param_groups:
|
for pgroup in param_groups:
|
||||||
# copy attribute from origin group, we assume the input param_groups only
|
# copy attribute from origin group, we assume the input param_groups only
|
||||||
|
@ -72,6 +73,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
||||||
|
|
||||||
# bf16 param group, which is the first group in the param groups
|
# bf16 param group, which is the first group in the param groups
|
||||||
pgroup["params"] = origin_params
|
pgroup["params"] = origin_params
|
||||||
|
pgroup["dp_mode"] = ParallelMode.DATA
|
||||||
|
|
||||||
# param groups may contain empty groups, such as fp32
|
# param groups may contain empty groups, such as fp32
|
||||||
param_groups.extend(new_groups.values())
|
param_groups.extend(new_groups.values())
|
||||||
|
|
|
@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
|
||||||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||||
|
|
||||||
states = optim.state_dict()
|
states = optim.state_dict()
|
||||||
if isinstance(optim, HybridZeroOptimizer):
|
if isinstance(optim, HybridZeroOptimizer):
|
||||||
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
|
if gpc.get_global_rank() < zero_size * tp_size * pp_size:
|
||||||
llm_save(os.path.join(state_path, fp), states)
|
llm_save(os.path.join(state_path, fp), states)
|
||||||
if "zero_devide_optim_plan" in states:
|
if "zero_devide_optim_plan" in states:
|
||||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||||
|
|
|
@ -5,7 +5,6 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||||
from internlm.core.context import global_context as gpc
|
from internlm.core.context import global_context as gpc
|
||||||
from internlm.model.utils import is_moe_param
|
|
||||||
|
|
||||||
|
|
||||||
def is_model_parallel_parameter(p):
|
def is_model_parallel_parameter(p):
|
||||||
|
@ -23,7 +22,7 @@ def sync_model_param(model):
|
||||||
gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
|
gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
|
||||||
)
|
)
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if sync_moe_param and is_moe_param(param):
|
if sync_moe_param and getattr(param, "is_expert", False):
|
||||||
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
|
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
|
||||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
|
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -86,13 +86,13 @@ ckpt_config_list = [
|
||||||
def overwrite_optim_state(optim, set_value):
|
def overwrite_optim_state(optim, set_value):
|
||||||
if isinstance(optim, HybridZeroOptimizer):
|
if isinstance(optim, HybridZeroOptimizer):
|
||||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||||
p.data.fill_(set_value)
|
p.data.fill_(set_value)
|
||||||
for group_id in range(len(optim._fp16_param_groups)):
|
for group_id in range(len(optim._fp16_param_groups)):
|
||||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
rank=optim._zero_local_rank, group_id=group_id
|
rank=optim._zero_local_rank[group_id], group_id=group_id
|
||||||
)
|
)
|
||||||
fp16_p.fill_(set_value)
|
fp16_p.fill_(set_value)
|
||||||
else:
|
else:
|
||||||
|
@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
|
||||||
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
||||||
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
||||||
re &= group_id_1 == group_id_2
|
re &= group_id_1 == group_id_2
|
||||||
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
|
if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||||
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
||||||
else:
|
else:
|
||||||
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
||||||
|
@ -122,12 +122,12 @@ def compare_optim_value(optim, value):
|
||||||
re = True
|
re = True
|
||||||
if isinstance(optim, HybridZeroOptimizer):
|
if isinstance(optim, HybridZeroOptimizer):
|
||||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||||
for group_id in range(len(optim._fp16_param_groups)):
|
for group_id in range(len(optim._fp16_param_groups)):
|
||||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||||
rank=optim._zero_local_rank, group_id=group_id
|
rank=optim._zero_local_rank[group_id], group_id=group_id
|
||||||
)
|
)
|
||||||
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
|
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue