memory profiling test

pull/407/head
yingtongxiong 2023-10-17 19:54:21 +08:00
parent 16ef7b7889
commit a5aeab2a3f
4 changed files with 28 additions and 19 deletions

View File

@@ -1,4 +1,4 @@
JOB_NAME = "13b_train"
JOB_NAME = "20b_train"
DO_ALERT = False
SEQ_LEN = 4096
@@ -51,7 +51,7 @@ data = dict(
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=4,
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
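
For reference, a quick check of the packed_length = micro_bsz * SEQ_LEN relationship described in the comment above, using the two values this commit moves between (illustrative snippet, not part of the diff):

    SEQ_LEN = 4096
    for micro_bsz in (4, 2):
        packed_length = micro_bsz * SEQ_LEN
        print(f"micro_bsz={micro_bsz} -> packed_length={packed_length} tokens per packed micro-batch")
    # micro_bsz=4 -> packed_length=16384 tokens per packed micro-batch
    # micro_bsz=2 -> packed_length=8192 tokens per packed micro-batch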

View File

@@ -423,7 +423,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.process_group = process_group
self.FSTP_blocks = []
self.FSTP_outs = []
self.FSTP_wqkvs = []
self.FSTP_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.FSTP_global_handle = dict() # key: FSTP module; value: module global all-gather op handle
@@ -465,9 +464,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if name == "out_proj":
self.FSTP_outs.append(child)
self.module_to_index[child] = idx
if name == "Wqkv":
self.FSTP_wqkvs.append(child)
self.module_to_index[child] = idx
if isinstance(child, FSTPLinear):
self.module_to_index[child] = idx
self.block_module[idx][index] = child
@@ -488,7 +484,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.embedding.append(children)
def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index]
#block = self.index_to_block[block_index]
fsdp_modules = self.index_to_fsdp_modules[block_index]
# self.block_handles[block] = []
for module in fsdp_modules:
@@ -552,12 +548,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
# start the all-gather for next block
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
def _pre_backward_hook_for_block(block: nn.Module, grad_output):
# import pdb; pdb.set_trace()
block_index = self.block_to_index[block]
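
The handler in this file appears to overlap weight all-gathers with backward compute: a pre-backward hook looks up the block index of the module being differentiated and starts the all-gather for a neighbouring block's weights (the removed _pre_backward_hook_for_wqkv prefetched block_index - 1). A minimal standalone sketch of that pattern; the names gather_module_weight and make_pre_backward_hook are illustrative, not the repository's API:

    # Minimal sketch, not the repository's implementation: each module is assumed to hold a
    # sharded flat weight in module.weight, and process_group is an initialized dist group.
    import torch
    import torch.distributed as dist

    def gather_module_weight(module, process_group):
        world_size = dist.get_world_size(process_group)
        shard = module.weight.data
        full_weight = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
        # async_op=True lets the copy overlap with the backward pass that is still running
        handle = dist.all_gather_into_tensor(full_weight, shard, group=process_group, async_op=True)
        return full_weight, handle

    def make_pre_backward_hook(prev_block_modules, process_group, weight_cache):
        # intended to be registered with block.register_full_backward_pre_hook(...)
        def hook(block, grad_output):
            for module in prev_block_modules:
                if module not in weight_cache:  # prefetch the previous block's weights once
                    weight_cache[module] = gather_module_weight(module, process_group)
        return hook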

View File

@@ -39,6 +39,14 @@ from .utils import compute_norm
inf = math.inf
logger = get_logger(__file__)
def print_memory(msg):
if gpc.get_global_rank() == 0:
print(msg, flush=True)
print("memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, flush=True)
print("max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
print("===========================================")
class HybridZeroOptimizer(BaseOptimizer):
"""
@@ -335,6 +343,7 @@ class HybridZeroOptimizer(BaseOptimizer):
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad += _grad
self._fstp_handler.reduce_scatter_handlers[key] = None
bucket.reset_by_rank(rank)
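
This hunk, and the similar one just below, follow the same pattern: wait on a stored (handle, gradient) pair from an earlier asynchronous reduce-scatter, accumulate the reduced shard into the parameter's gradient, then overwrite the dictionary entry with None so the communication buffers can be freed. A hedged standalone sketch of that pattern; launch_reduce_scatter and consume_reduce_scatter are illustrative names, not the handler's real interface:

    import torch
    import torch.distributed as dist

    reduce_scatter_handlers = {}    # key -> (async work handle, reduced gradient shard)

    def launch_reduce_scatter(key, full_grad, group):
        world_size = dist.get_world_size(group)
        shard = torch.empty(full_grad.numel() // world_size, dtype=full_grad.dtype, device=full_grad.device)
        handle = dist.reduce_scatter_tensor(shard, full_grad, group=group, async_op=True)
        reduce_scatter_handlers[key] = (handle, shard)

    def consume_reduce_scatter(key, param):
        handle, grad_shard = reduce_scatter_handlers[key]
        handle.wait()                        # block until the reduce-scatter has finished
        param.grad += grad_shard             # accumulate the reduced shard (shapes assumed to match)
        reduce_scatter_handlers[key] = None  # drop the references so the buffers can be freed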
@@ -358,6 +367,7 @@ class HybridZeroOptimizer(BaseOptimizer):
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad += _grad
self._fstp_handler.reduce_scatter_handlers[key] = None
# reduce grad
if self.skip_grad_reduce is False:
@@ -565,6 +575,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# if not overlapping communication (no reduction hook is attached)
# we need to manually reduce these gradients
print_memory("No 1")
if not self._overlap_sync_grad:
for group_id in range(len(self._fp16_param_groups)):
for param in self._fp16_param_groups[group_id]:
@@ -589,7 +600,7 @@ class HybridZeroOptimizer(BaseOptimizer):
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
print_memory("No 2")
# compute norm for gradients in the last bucket
total_norms = {}
for group_id in range(self.num_param_groups):
@@ -611,10 +622,12 @@ class HybridZeroOptimizer(BaseOptimizer):
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
total_norms[group_name] = scaled_norm_tensor.item()
print_memory("No 3")
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()
print_memory("No 4")
return self._step(closure=closure, norms=total_norms)
@@ -661,7 +674,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, norms
print_memory("No 5")
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
for group_id in range(self.num_param_groups):
@@ -702,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
print_memory("No 6")
# unscale and clip grads
# get the global norm
global_norm_groups = {}
@@ -725,9 +738,12 @@ class HybridZeroOptimizer(BaseOptimizer):
# For those ranks that are not assigned parameters, we just wait for other ranks
# to send them their own updated parameters.
if self.has_params:
print_memory("No 7")
self.optim.step()
print_memory("No 8")
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
print_memory("No 9")
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]:
@@ -736,17 +752,18 @@ class HybridZeroOptimizer(BaseOptimizer):
)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
print_memory("No 10")
torch.cuda.synchronize()
with torch.cuda.stream(self._comm_bcast_stream):
self.broadcast_params()
timer("step").stop()
# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale
print_memory("No 11")
return True, global_norm_groups
def broadcast_params(self):
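
The code around print_memory("No 10") synchronizes the device and then runs broadcast_params inside torch.cuda.stream(self._comm_bcast_stream), so the post-step parameter broadcast is queued on a dedicated stream. A minimal sketch of that side-stream pattern (illustrative only; the real broadcast_params iterates over the ZeRO parameter shards):

    import torch
    import torch.distributed as dist

    comm_bcast_stream = torch.cuda.Stream()

    def broadcast_updated_params(params, src_rank, group):
        torch.cuda.synchronize()             # make sure the optimizer step's kernels have finished
        with torch.cuda.stream(comm_bcast_stream):
            for p in params:
                dist.broadcast(p.data, src=src_rank, group=group, async_op=True)
        # callers should wait on comm_bcast_stream before reading the parameters again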

View File

@@ -296,6 +296,8 @@ def main(args):
if batch_count % 2 == 0:
prof.step()
torch.cuda.reset_peak_memory_stats()
ckpt_manager.wait_async_upload_finish()
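
With torch.cuda.reset_peak_memory_stats() added next to prof.step(), each two-step profiling window reports its own peak allocation rather than the run-wide maximum. A sketch of that per-window measurement loop; train_step and num_steps are placeholders, not the repository's API:

    import torch

    def train_step(step):            # placeholder for one training iteration
        pass

    num_steps = 6                    # placeholder
    for batch_count in range(num_steps):
        train_step(batch_count)
        if batch_count % 2 == 0:
            peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
            print(f"step {batch_count}: peak {peak_gib:.2f} GiB since last reset", flush=True)
            torch.cuda.reset_peak_memory_stats()   # the next window's peak starts from the current allocation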