mirror of https://github.com/InternLM/InternLM
feat(optimizer/hybrid_zero_optim.py): resolve conflicts
commit 3c6925499f

This commit resolves the merge conflicts around the FSTP reduce-scatter gradient-accumulation path: `_bucket_store_2` is renamed to `_accum_grad_buckets` (typed as `List[BucketStore]`), the scheduler-driven `reset_reduce_bucket()` is folded into a new `_accum_grads_store_in_bucket()` helper, `_wait_reduce_scatter_and_accumulate_grad` becomes `_wait_reduce_scatter_and_accumulate_grads`, and the flush threshold changes from a hard-coded `512 * 1024 * 1024` to `self._reduce_bucket_size`.

@@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler):
             _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )
-            engine.optimizer.reset_reduce_bucket()

             if return_loss:
                 loss += _loss

@@ -3,6 +3,7 @@

 import math
 from functools import partial
+from typing import List, Optional

 import torch
 import torch.distributed as dist

@@ -105,8 +106,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
-        self._bucket_store = []
-        self._bucket_store_2 = []
+        self._bucket_store: List[BucketStore] = []
+        self._accum_grad_buckets: List[BucketStore] = []
         self._bucket_in_progress = []

         # fp16 and fp32 params for mixed precision training

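Both stores are now typed lists of `BucketStore`, and the second one gets a name that says what it is for: buffering parameters whose FSTP reduce-scatter results still have to be accumulated. For readers unfamiliar with the store's interface, here is a minimal sketch of the subset this diff relies on; `MiniBucketStore` is a hypothetical stand-in for illustration, not InternLM's actual `BucketStore`:

```python
from collections import defaultdict
from typing import Dict, List, Optional


class MiniBucketStore:
    """Hypothetical, simplified stand-in for InternLM's BucketStore.

    Tracks, per reduce rank, the parameters queued for reduction and the
    total number of gradient elements they contribute.
    """

    def __init__(self, group_id: int, dp_mode) -> None:
        self._group_id = group_id
        self._dp_mode = dp_mode
        self._params: Dict[Optional[int], List] = defaultdict(list)
        self._num_elements: Dict[Optional[int], int] = defaultdict(int)

    def num_elements_in_bucket(self, reduce_rank: Optional[int] = None) -> int:
        return self._num_elements[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements: int, reduce_rank: Optional[int] = None) -> None:
        self._num_elements[reduce_rank] += num_elements

    def add_param(self, param, reduce_rank: Optional[int] = None) -> None:
        self._params[reduce_rank].append(param)

    def get_param(self, reduce_rank: Optional[int] = None) -> List:
        return self._params[reduce_rank]

    def reset_by_rank(self, reduce_rank: Optional[int] = None) -> None:
        self._params[reduce_rank] = []
        self._num_elements[reduce_rank] = 0
```
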
@@ -175,7 +176,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
             self._broadcast_parallel_mode.append(zero_mode)
             self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
-            self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"]))
+            self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"]))

             # assign parameters to ranks the params in the list are sorted
             params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)

@@ -323,7 +324,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )

                 reduce_scatter_checker = partial(
-                    self._wait_reduce_scatter_and_accumulate_grad,
+                    self._wait_reduce_scatter_and_accumulate_grads,
                     param=param,
                     reduce_rank=reduce_rank,
                 )

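`functools.partial` pre-binds `param` and `reduce_rank` so the autograd hook below can invoke the checker with no arguments. A trivial illustration of the binding (names are made up for the example):

```python
from functools import partial


def check(param=None, reduce_rank=None):
    print(f"param={param}, reduce_rank={reduce_rank}")


checker = partial(check, param="weight_0", reduce_rank=None)
checker()  # prints: param=weight_0, reduce_rank=None
```
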
@@ -332,7 +333,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # NOT IMPORTANT BUT GOOD TO KNOW:
                 # args here is not grad, but allow_unreachable and accumulate_grad
                 def reduce_grad_hook(*args):  # pylint: disable=W0613
-                    if gpc.config.fstp_handler is not None:
+                    if self._fstp_handler is not None:
                         reduce_scatter_checker()

                     if self.skip_grad_reduce is False:

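`reduce_grad_hook` now consults the handler cached on the optimizer (`self._fstp_handler`) instead of reaching into `gpc.config` at every backward. The surrounding comment refers to the common ZeRO trick of hooking a parameter's `AccumulateGrad` node, so the hook fires only after the gradient for the current backward has actually been written into `param.grad`. A minimal self-contained sketch of that pattern; the helper name `attach_grad_ready_hook` is invented for illustration and is not InternLM's API:

```python
import torch

# Keep references to the AccumulateGrad nodes: they are cached weakly by
# autograd, so without a strong reference the registered hooks can be lost.
_grad_acc_refs = []


def attach_grad_ready_hook(param: torch.nn.Parameter, on_grad_ready) -> None:
    # expand_as(param) builds a tiny graph whose next function is the
    # parameter's AccumulateGrad node.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

    def hook(*args):  # args are grad_inputs/grad_outputs, not the grad itself
        on_grad_ready(param)

    grad_acc.register_hook(hook)
    _grad_acc_refs.append(grad_acc)


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.ones(4))
    attach_grad_ready_hook(p, lambda q: print("grad ready:", q.grad))
    (p * 2.0).sum().backward()  # hook fires after p.grad is populated
```
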
@@ -356,84 +357,36 @@ class HybridZeroOptimizer(BaseOptimizer):
         group_id = getattr(param, "group_id")
         return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])

-    def reset_reduce_bucket(self) -> None:
-        for bucket in self._bucket_store_2:
-            for rank, params in bucket._params.items():
-                for _param in params:
-                    if not hasattr(_param, "_fstp_reduce_scatter_str"):
-                        continue
-
-                    key = getattr(_param, "_fstp_reduce_scatter_str")
-                    comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-                    comm_handle.wait()
-                    _param.grad.add_(_grad)
-                    # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
-                    del self._fstp_handler.reduce_scatter_handlers[key]
-                    self._fstp_handler.reduce_scatter_handlers[key] = None
-                    assert key in self._fstp_handler.reduce_scatter_handlers
-
-                    # if not hasattr(_param, "_fstp_all_reduce_str"):
-                    #     continue
-
-                    # key = getattr(_param, "_fstp_all_reduce_str")
-                    # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                    # comm_handle.wait()
-                    # with torch.no_grad():
-                    #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                    #     _param.grad.add_(_grad)
-                    # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    # del self._fstp_handler.all_reduce_handlers[key]
-                    # self._fstp_handler.all_reduce_handlers[key] = None
-                    # assert key in self._fstp_handler.all_reduce_handlers
-
-                bucket.reset_by_rank(rank)
-
-    def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
+    def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None:
+        for _param in bucket.get_param(reduce_rank):
+            if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                continue
+
+            # wait and accumulate gradient.
+            _key = getattr(_param, "_fstp_reduce_scatter_str")
+            _comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key]
+            _comm_handle.wait()
+            _param.grad.add_(_grad)
+
+            # release cuda memory.
+            release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
+            self._fstp_handler.reduce_scatter_handlers[_key] = None
+
+        bucket.reset_by_rank(reduce_rank)
+
+    def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None):
         param_size = param.numel()

+        group_id = getattr(param, "group_id")
+        current_bucket = self._accum_grad_buckets[group_id]
+
         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        group_id = getattr(param, "group_id")
-        current_bucket = self._bucket_store_2[group_id]
-
-        if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024:
-            # wait reduce scatter communication
-            params = current_bucket.get_param(reduce_rank)
-            for _param in params:
-                if not hasattr(_param, "_fstp_reduce_scatter_str"):
-                    continue
-
-                key = getattr(_param, "_fstp_reduce_scatter_str")
-                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-                comm_handle.wait()
-                _param.grad.add_(_grad)
-                # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
-                del self._fstp_handler.reduce_scatter_handlers[key]
-                self._fstp_handler.reduce_scatter_handlers[key] = None
-                assert key in self._fstp_handler.reduce_scatter_handlers
-
-                # if not hasattr(_param, "_fstp_all_reduce_str"):
-                #     continue
-
-                # key = getattr(_param, "_fstp_all_reduce_str")
-                # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                # comm_handle.wait()
-                # with torch.no_grad():
-                #     _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                #     _param.grad.add_(_grad)
-                # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                # del self._fstp_handler.all_reduce_handlers[key]
-                # self._fstp_handler.all_reduce_handlers[key] = None
-                # assert key in self._fstp_handler.all_reduce_handlers
-
-            current_bucket.reset_by_rank(reduce_rank)
+        if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
+            self._accum_grads_store_in_bucket(current_bucket, reduce_rank)

         # otherwise, add the parameter into bucket.
         current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
         current_bucket.add_param(param, reduce_rank)

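The refactor splits the old inline flush into a reusable `_accum_grads_store_in_bucket` and keeps `_wait_reduce_scatter_and_accumulate_grads` as the per-parameter entry point, with the flush threshold now tied to `self._reduce_bucket_size`. A toy sketch of that control flow, which pairs with the `MiniBucketStore` sketch earlier; `FakeCommHandle` and the free functions below are illustrative stand-ins, not InternLM's actual FSTP machinery:

```python
from typing import Dict, Optional


class FakeCommHandle:
    """Stands in for the async reduce-scatter handle; wait() would block
    until the communication kernel has produced the reduced gradient shard."""

    def __init__(self, reduced_grad):
        self.reduced_grad = reduced_grad

    def wait(self) -> None:
        pass  # the real handle synchronizes with the communication stream


def accum_grads_store_in_bucket(bucket, handlers: Dict, reduce_rank: Optional[int] = None) -> None:
    # mirrors _accum_grads_store_in_bucket: wait on each pending handle, fold
    # the reduced shard into param.grad, drop the buffer, then empty the bucket
    for param in bucket.get_param(reduce_rank):
        key = getattr(param, "_fstp_reduce_scatter_str", None)
        if key is None:
            continue
        handle = handlers[key]
        handle.wait()
        param.grad.add_(handle.reduced_grad)
        handlers[key] = None  # release the reference so the pool slot can be reused
    bucket.reset_by_rank(reduce_rank)


def wait_reduce_scatter_and_accumulate_grads(bucket, handlers: Dict, param,
                                             reduce_bucket_size: int,
                                             reduce_rank: Optional[int] = None) -> None:
    # mirrors the new method: flush first if the bucket is over the threshold
    # (previously hard-coded to 512 * 1024 * 1024 elements), then enqueue param
    if bucket.num_elements_in_bucket(reduce_rank) >= reduce_bucket_size:
        accum_grads_store_in_bucket(bucket, handlers, reduce_rank)
    bucket.add_num_elements_in_bucket(param.numel(), reduce_rank)
    bucket.add_param(param, reduce_rank)
```
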
@@ -661,6 +614,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)

+        # we need to accumulate gradients left in the accumulate gradient bucket
+        for group_id in range(self.num_param_groups):
+            self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
+
         # compute norm for gradients in the before bucket
         groups_norms = []
         for group_id in range(self.num_param_groups):

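At step time the optimizer now drains both kinds of buckets before any norm is computed: first the ZeRO reduce buckets (`last_bucket=True`), then the accumulate-grad buckets, so every `param.grad` is complete when `groups_norms` is filled. A condensed, hypothetical view of that ordering, assuming an optimizer object exposing the attributes the diff touches (this is not the real `_step` body):

```python
def pre_norm_flush(opt) -> None:
    """Drain all buckets so gradients are final before norm computation."""
    # 1) reduce whatever is still queued in the ZeRO reduce buckets
    for group_id in range(opt.num_param_groups):
        opt._reduce_grads_stored_in_bucket(opt._bucket_store[group_id], reduce_rank=None, last_bucket=True)

    # 2) wait on outstanding FSTP reduce-scatters and accumulate their shards
    for group_id in range(opt.num_param_groups):
        opt._accum_grads_store_in_bucket(opt._accum_grad_buckets[group_id], reduce_rank=None)

    # only after both flushes is it safe to compute per-group gradient norms
```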