mirror of https://github.com/InternLM/InternLM
memory profiling test
parent 16ef7b7889
commit a5aeab2a3f
@@ -1,4 +1,4 @@
-JOB_NAME = "13b_train"
+JOB_NAME = "20b_train"
 DO_ALERT = False

 SEQ_LEN = 4096
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -423,7 +423,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.process_group = process_group
         self.FSTP_blocks = []
         self.FSTP_outs = []
-        self.FSTP_wqkvs = []
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
@@ -465,9 +464,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     if name == "out_proj":
                         self.FSTP_outs.append(child)
                         self.module_to_index[child] = idx
-                    if name == "Wqkv":
-                        self.FSTP_wqkvs.append(child)
-                        self.module_to_index[child] = idx
                     if isinstance(child, FSTPLinear):
                         self.module_to_index[child] = idx
                         self.block_module[idx][index] = child
@@ -488,7 +484,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     self.embedding.append(children)

     def _all_gather_block_weight(self, block_index: int):
-        block = self.index_to_block[block_index]
+        # block = self.index_to_block[block_index]
         fsdp_modules = self.index_to_fsdp_modules[block_index]
         # self.block_handles[block] = []
         for module in fsdp_modules:
@@ -552,12 +548,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]

-        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
-            block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index - 1 >= 0:
-                self._all_gather_block_weight(block_index - 1)
-
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             # import pdb; pdb.set_trace()
             block_index = self.block_to_index[block]
@@ -39,6 +39,14 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)

+def print_memory(msg):
+
+    if gpc.get_global_rank() == 0:
+        print(msg, flush=True)
+        print("memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("===========================================")
+

 class HybridZeroOptimizer(BaseOptimizer):
     """
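The hunk above adds the print_memory helper that the checkpoints in the following hunks call: on the global rank-0 process it prints a label plus the live and peak CUDA memory in GiB. A minimal standalone sketch of the same pattern, assuming plain torch.distributed for the rank check instead of InternLM's gpc context, could look like this:

    import torch
    import torch.distributed as dist


    def print_memory(msg):
        # Report only from rank 0 so the log is not interleaved across workers.
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank == 0:
            gib = 1024 ** 3
            print(msg, flush=True)
            # Memory currently held by live tensors on the default CUDA device, in GiB.
            print("memory allocated: ", torch.cuda.memory_allocated() / gib, flush=True)
            # Peak since program start (or since the last torch.cuda.reset_peak_memory_stats()).
            print("max memory allocated: ", torch.cuda.max_memory_allocated() / gib, flush=True)
            print("===========================================")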
@@ -335,6 +343,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
                     _param.grad += _grad
+                    self._fstp_handler.reduce_scatter_handlers[key] = None

                 bucket.reset_by_rank(rank)

@@ -358,6 +367,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
                 _param.grad += _grad
+                self._fstp_handler.reduce_scatter_handlers[key] = None

         # reduce grad
         if self.skip_grad_reduce is False:
@@ -565,6 +575,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
+        print_memory("No 1")
         if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
@@ -589,7 +600,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket.empty()
         self._bucket_in_progress = []
         self._param_store.clear_grads_of_previous_reduced_params()
-
+        print_memory("No 2")
         # compute norm for gradients in the last bucket
         total_norms = {}
         for group_id in range(self.num_param_groups):
@@ -611,10 +622,12 @@ class HybridZeroOptimizer(BaseOptimizer):
             scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
             dist.all_reduce(scaled_norm_tensor, group=pg)
             total_norms[group_name] = scaled_norm_tensor.item()

+        print_memory("No 3")
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

+        print_memory("No 4")

         return self._step(closure=closure, norms=total_norms)
@@ -661,7 +674,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, norms
-
+        print_memory("No 5")
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -702,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-
+        print_memory("No 6")
         # unscale and clip grads
         # get the global norm
         global_norm_groups = {}
@@ -725,9 +738,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
+            print_memory("No 7")
             self.optim.step()
+            print_memory("No 8")
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+            print_memory("No 9")
             # update fp16 partition updated by the current rank
             for group_id in range(len(self._fp16_param_groups)):
                 if self.param_group_has_params[group_id]:
@@ -736,17 +752,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
-
+        print_memory("No 10")
         torch.cuda.synchronize()
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()

         timer("step").stop()

         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
+        print_memory("No 11")
         return True, global_norm_groups

     def broadcast_params(self):
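Taken together, the optimizer hunks bracket the phases of HybridZeroOptimizer.step() with numbered checkpoints ("No 1" through "No 11") and clear the stored reduce-scatter handle and gradient (reduce_scatter_handlers[key] = None) once they have been consumed, so a jump in allocated or peak memory can be localized by comparing consecutive reports. A sketch of a delta-printing variant of the checkpoint helper, using a hypothetical print_memory_delta that is not part of this commit, might look like this:

    import torch

    _last_allocated = 0


    def print_memory_delta(tag):
        # Hypothetical variant: also print the change in allocated memory since
        # the previous checkpoint so the offending phase stands out directly.
        global _last_allocated
        gib = 1024 ** 3
        current = torch.cuda.memory_allocated()
        peak = torch.cuda.max_memory_allocated()
        print(
            f"{tag}: allocated {current / gib:.3f} GiB "
            f"(delta {(current - _last_allocated) / gib:+.3f} GiB), peak {peak / gib:.3f} GiB",
            flush=True,
        )
        _last_allocated = current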