update optimizer accumulate grad impl when fstp

pull/456/head
chenxun.p 2023-10-20 15:58:06 +08:00
parent 815a584930
commit 95488d8e8f
2 changed files with 51 additions and 83 deletions
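Note for readers: with FSTP overlap enabled, each weight gradient is reduce-scattered asynchronously during the backward pass, and this commit moves the "wait for the handle, then accumulate the shard" bookkeeping out of the scheduler and into dedicated per-group accumulate-grad buckets inside HybridZeroOptimizer. A minimal sketch of the asynchronous half of that pattern (helper name and shapes are illustrative; assumes torch.distributed is already initialized):

    import torch
    import torch.distributed as dist

    def reduce_scatter_grad_async(full_grad: torch.Tensor, world_size: int):
        # Issue an asynchronous reduce-scatter of a full weight gradient.
        # The (handle, shard) pair is the kind of entry stored in
        # reduce_scatter_handlers and waited on later by the optimizer.
        shard = torch.empty(
            full_grad.shape[0] // world_size,
            *full_grad.shape[1:],
            dtype=full_grad.dtype,
            device=full_grad.device,
        )
        handle = dist.reduce_scatter_tensor(shard, full_grad, async_op=True)
        return handle, shard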


@@ -194,7 +194,6 @@ class NonPipelineScheduler(BaseScheduler):
_output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
- engine.optimizer.reset_reduce_bucket()
if return_loss:
loss += _loss
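The call removed above used to drain the optimizer's secondary bucket once per micro-batch; after this commit the optimizer drains its accumulate-grad buckets itself (see _accum_grads_store_in_bucket below), so the scheduler's accumulation loop stays generic. Roughly (method names here are placeholders, not the scheduler's real API):

    def forward_backward(engine, data_iter, grad_accum_size, return_loss=True):
        loss = 0.0
        for _ in range(grad_accum_size):
            data, label = next(data_iter)
            _output, _loss, _moe_loss = engine.train_one_batch(data, label)
            if return_loss:
                loss += _loss
            # no engine.optimizer.reset_reduce_bucket() per micro-batch anymore;
            # leftover grads are accumulated inside the optimizer before step()
        return loss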


@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import math
+ from typing import Optional, List
from functools import partial
import torch
@@ -40,8 +41,20 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)
def print_memory(msg):
-    print(msg, " rank = ", gpc.get_global_rank(), " memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, " reserved memory: ", torch.cuda.memory_reserved() / 1024 / 1024 / 1024, " max memory: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+    print(
+        msg,
+        " rank = ",
+        gpc.get_global_rank(),
+        " memory allocated: ",
+        torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+        " reserved memory: ",
+        torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+        " max memory: ",
+        torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+        flush=True,
+    )
print("===========================================")
@@ -69,7 +82,7 @@ class HybridZeroOptimizer(BaseOptimizer):
backoff_factor = grad_scal_cfg.backoff_factor
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale
if gpc.config.parallel["tensor"]["mode"] == "fstp" and gpc.config.parallel["tensor"]["overlap"] == True:
self._fstp_handler = gpc.config.fstp_handler
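The condition above gates the overlap path on the tensor-parallel settings. A hypothetical config fragment consistent with that check (only "mode" and "overlap" are actually read here; the surrounding structure is assumed):

    # assumed training-config fragment, for illustration only
    parallel = dict(
        tensor=dict(size=8, mode="fstp", overlap=True),
    )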
@@ -90,8 +103,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
- self._bucket_store = []
- self._bucket_store_2 = []
+ self._bucket_store: List[BucketStore] = []
+ self._accum_grad_buckets: List[BucketStore] = []
self._bucket_in_progress = []
# fp16 and fp32 params for mixed precision training
@@ -160,7 +173,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
self._broadcast_parallel_mode.append(zero_mode)
self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
- self._bucket_store_2.append(BucketStore(group_id, param_group["dp_mode"]))
+ self._accum_grad_buckets.append(BucketStore(group_id, param_group["dp_mode"]))
# assign parameters to ranks the params in the list are sorted
params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
@@ -306,9 +319,9 @@ class HybridZeroOptimizer(BaseOptimizer):
param=param,
reduce_rank=reduce_rank,
)
reduce_scatter_checker = partial(
- self._wait_reduce_scatter_and_accumulate_grad,
+ self._wait_reduce_scatter_and_accumulate_grads,
param=param,
reduce_rank=reduce_rank,
)
@@ -317,7 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args): # pylint: disable=W0613
- if gpc.config.fstp_handler is not None:
+ if self._fstp_handler is not None:
reduce_scatter_checker()
if self.skip_grad_reduce is False:
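reduce_grad_hook fires once a parameter's gradient has been fully accumulated during backward; with an FSTP handler present it now also drains finished reduce-scatter results via reduce_scatter_checker. The diff does not show how the hook is attached; one common way to get a post-accumulation hook (an assumption here, not confirmed by this commit) is to register it on the parameter's AccumulateGrad autograd node:

    import torch

    def attach_accum_grad_hook(param: torch.nn.Parameter, hook) -> None:
        # expand_as produces a view whose grad_fn points back at the parameter's
        # AccumulateGrad node; hooks on that node run after .grad is accumulated.
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(hook)

    # usage sketch:
    # attach_accum_grad_hook(p, lambda *args: print("grad ready", p.shape))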
@@ -341,84 +354,36 @@ class HybridZeroOptimizer(BaseOptimizer):
group_id = getattr(param, "group_id")
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
-    def reset_reduce_bucket(self) -> None:
-        for bucket in self._bucket_store_2:
-            for rank, params in bucket._params.items():
-                for _param in params:
-                    if not hasattr(_param, "_fstp_reduce_scatter_str"):
-                        continue
-
-                    key = getattr(_param, "_fstp_reduce_scatter_str")
-                    comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-                    comm_handle.wait()
-                    _param.grad.add_(_grad)
-                    # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
-                    del self._fstp_handler.reduce_scatter_handlers[key]
-                    self._fstp_handler.reduce_scatter_handlers[key] = None
-                    assert key in self._fstp_handler.reduce_scatter_handlers
-
-                    # if not hasattr(_param, "_fstp_all_reduce_str"):
-                    #     continue
-
-                    # key = getattr(_param, "_fstp_all_reduce_str")
-                    # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                    # comm_handle.wait()
-                    # with torch.no_grad():
-                    # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                    # _param.grad.add_(_grad)
-                    # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    # del _grad
-                    # del self._fstp_handler.all_reduce_handlers[key]
-                    # self._fstp_handler.all_reduce_handlers[key] = None
-                    # assert key in self._fstp_handler.all_reduce_handlers
-
-                bucket.reset_by_rank(rank)
+    def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optional[int] = None) -> None:
+        for _param in bucket.get_param(reduce_rank):
+            if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                continue
+
+            # wait and accumulate gradient.
+            _key = getattr(_param, "_fstp_reduce_scatter_str")
+            _comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[_key]
+            _comm_handle.wait()
+            _param.grad.add_(_grad)
+
+            # release cuda memory.
+            self._fstp_handler.reduce_scatter_handlers[_key] = None
+            _grad = None
+
+        bucket.reset_by_rank(reduce_rank)

-    def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
+    def _wait_reduce_scatter_and_accumulate_grads(self, param, reduce_rank: Optional[int] = None):
         param_size = param.numel()
+        group_id = getattr(param, "group_id")
+        current_bucket = self._accum_grad_buckets[group_id]

         # check if the bucket is full
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
-        group_id = getattr(param, "group_id")
-        current_bucket = self._bucket_store_2[group_id]
-
-        if current_bucket.num_elements_in_bucket(reduce_rank) >= 512 * 1024 * 1024:
-            # wait reduce scatter communication
-            params = current_bucket.get_param(reduce_rank)
-            for _param in params:
-                if not hasattr(_param, "_fstp_reduce_scatter_str"):
-                    continue
-
-                key = getattr(_param, "_fstp_reduce_scatter_str")
-                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
-                comm_handle.wait()
-                _param.grad.add_(_grad)
-                # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
-                del self._fstp_handler.reduce_scatter_handlers[key]
-                self._fstp_handler.reduce_scatter_handlers[key] = None
-                assert key in self._fstp_handler.reduce_scatter_handlers
-
-                # if not hasattr(_param, "_fstp_all_reduce_str"):
-                #     continue
-
-                # key = getattr(_param, "_fstp_all_reduce_str")
-                # comm_handle, _grad = self._fstp_handler.all_reduce_handlers[key]
-                # comm_handle.wait()
-                # with torch.no_grad():
-                # _grad = split_forward_gather_backward(_grad, ParallelMode.TENSOR, dim=0)
-                # _param.grad.add_(_grad)
-                # # self._fstp_handler.reduce_scatter_handlers[key] = None
-                # del _grad
-                # del self._fstp_handler.all_reduce_handlers[key]
-                # self._fstp_handler.all_reduce_handlers[key] = None
-                # assert key in self._fstp_handler.all_reduce_handlers
-
-            current_bucket.reset_by_rank(reduce_rank)
+        if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
+            self._accum_grads_store_in_bucket(current_bucket, reduce_rank)

         # otherwise, add the parameter into bucket.
         current_bucket.add_num_elements_in_bucket(param_size, reduce_rank)
         current_bucket.add_param(param, reduce_rank)
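The drain step above (wait on the communication handle, add_ the shard into param.grad, drop the handler entry) is what keeps gradient accumulation correct across micro-batches: every pending reduce-scatter must be folded into .grad before the optimizer clips or applies it. A minimal illustration of that wait-and-accumulate step (the handler-map layout is an assumption for the example):

    import torch

    def wait_and_accumulate(params, handlers):
        # handlers: maps a per-parameter key to (async_work_handle, grad_shard);
        # this layout is assumed for illustration, not taken from the repo.
        for p in params:
            work, grad_shard = handlers.pop(id(p))
            work.wait()                      # finish the async reduce-scatter
            if p.grad is None:
                p.grad = grad_shard.clone()  # first micro-batch
            else:
                p.grad.add_(grad_shard)      # accumulate later micro-batches
            # dropping the handler entry lets the communication buffer be reused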
@@ -646,6 +611,10 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(self.num_param_groups):
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
+ # we need to accumulate gradients left in the accumulate gradient bucket
+ for group_id in range(self.num_param_groups):
+     self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id], reduce_rank=None)
# compute norm for gradients in the before bucket
groups_norms = []
for group_id in range(self.num_param_groups):
@@ -685,16 +654,16 @@ class HybridZeroOptimizer(BaseOptimizer):
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
print_memory("No 4")
try:
res = self._step(closure=closure, norms=total_norms)
except torch.cuda.OutOfMemoryError as e:
print(e, flush=True)
print(torch.cuda.memory_summary(), flush=True)
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
return res
def _step(self, closure=None, norms=None):
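The new exception handler dumps an allocator snapshot when the step runs out of memory. torch.cuda.memory._dump_snapshot is a private PyTorch API (recent releases); for the snapshot to contain allocation history it is usually paired with _record_memory_history and inspected with the memory_viz tool, e.g.:

    import torch

    # enable allocation-history recording early in the program
    # (private API, PyTorch >= 2.1)
    torch.cuda.memory._record_memory_history(max_entries=100000)

    try:
        run_training_step()  # placeholder for the OOM-prone work
    except torch.cuda.OutOfMemoryError:
        torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
        # open the pickle at https://pytorch.org/memory_viz
        raise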
@@ -822,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer):
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()
timer("step").stop()
# update gradients may not be needed here, because the sync_params function is used in initialization,