mirror of https://github.com/InternLM/InternLM
memory profiling test
parent 16ef7b7889
commit a5aeab2a3f
@@ -1,4 +1,4 @@
-JOB_NAME = "13b_train"
+JOB_NAME = "20b_train"
 DO_ALERT = False
 
 SEQ_LEN = 4096
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
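The config comment ties packed_length to micro_bsz and SEQ_LEN; a minimal sketch of that arithmetic (the packed_length helper is hypothetical, not part of the config) shows what halving micro_bsz does to the tokens per packed micro batch:

SEQ_LEN = 4096  # matches the config above

def packed_length(micro_bsz: int, seq_len: int = SEQ_LEN) -> int:
    # packed_length = micro_bsz * SEQ_LEN, as stated in the config comment
    return micro_bsz * seq_len

print(packed_length(4))  # old setting: 16384 tokens per packed micro batch
print(packed_length(2))  # new setting:  8192 tokens per packed micro batch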
@@ -423,7 +423,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.process_group = process_group
         self.FSTP_blocks = []
         self.FSTP_outs = []
-        self.FSTP_wqkvs = []
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
@@ -465,9 +464,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                    if name == "out_proj":
                        self.FSTP_outs.append(child)
                        self.module_to_index[child] = idx
-                    if name == "Wqkv":
-                        self.FSTP_wqkvs.append(child)
-                        self.module_to_index[child] = idx
                    if isinstance(child, FSTPLinear):
                        self.module_to_index[child] = idx
                        self.block_module[idx][index] = child
@@ -488,7 +484,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 self.embedding.append(children)
 
     def _all_gather_block_weight(self, block_index: int):
-        block = self.index_to_block[block_index]
+        #block = self.index_to_block[block_index]
         fsdp_modules = self.index_to_fsdp_modules[block_index]
         # self.block_handles[block] = []
         for module in fsdp_modules:
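For context on _all_gather_block_weight: the handler launches one asynchronous all-gather per sharded module of a block and keeps the returned handles so later hooks can wait on them. Below is a runnable single-process sketch of that pattern (gloo backend, world size 1); the module names and the handles dict are illustrative, not the repository's API.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

block_shards = {"Wqkv": torch.randn(4), "out_proj": torch.randn(4)}
handles = {}
gathered = {}

for name, shard in block_shards.items():
    out = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    handles[name] = dist.all_gather(out, shard, async_op=True)  # non-blocking, returns a Work handle
    gathered[name] = out

for name, work in handles.items():
    work.wait()  # the full weight for this module is now in gathered[name]

dist.destroy_process_group()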
@@ -552,12 +548,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if module in self.FSTP_global_handle:
                 del self.FSTP_global_handle[module]
 
-        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
-            block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index - 1 >= 0:
-                self._all_gather_block_weight(block_index - 1)
-
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             # import pdb; pdb.set_trace()
             block_index = self.block_to_index[block]
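The deleted _pre_backward_hook_for_wqkv (its per-block counterpart remains) illustrates the prefetch idea: when backward reaches block i, start the all-gather for block i - 1 so its full weights are ready in time. A minimal sketch of that hook pattern follows, assuming PyTorch >= 2.0 for register_full_backward_pre_hook; PrefetchHandler and its prefetch() are stand-ins for the handler's _all_gather_block_weight, not the repository's code.

import torch
import torch.nn as nn


class PrefetchHandler:
    def __init__(self, blocks):
        self.block_to_index = {blk: i for i, blk in enumerate(blocks)}

    def prefetch(self, block_index: int) -> None:
        # Placeholder for an asynchronous all-gather of that block's sharded weights.
        print(f"prefetching weights of block {block_index}")

    def register(self, blocks) -> None:
        def _pre_backward_hook(block, grad_output):
            idx = self.block_to_index[block]
            if idx - 1 >= 0:
                self.prefetch(idx - 1)

        for block in blocks:
            block.register_full_backward_pre_hook(_pre_backward_hook)


blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
PrefetchHandler(blocks).register(blocks)
x = torch.randn(4, 8)
for blk in blocks:
    x = blk(x)
x.sum().backward()  # hooks fire in reverse block order: prefetch block 1, then block 0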
@@ -39,6 +39,14 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)
 
 
+def print_memory(msg):
+    if gpc.get_global_rank() == 0:
+        print(msg, flush=True)
+        print("memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("===========================================")
+
+
 class HybridZeroOptimizer(BaseOptimizer):
     """
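A standalone sketch of the helper added above, for running it outside the trainer: gpc.get_global_rank() is replaced by torch.distributed with a single-process fallback (an assumption, not the repository's code), and the byte counts are reported in GiB as in the original.

import torch
import torch.distributed as dist


def print_memory(msg: str) -> None:
    # Rank-0-only report; falls back to rank 0 when torch.distributed is not initialized.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        gib = 1024 ** 3
        print(msg, flush=True)
        print("memory allocated: ", torch.cuda.memory_allocated() / gib, flush=True)
        print("max memory allocated: ", torch.cuda.max_memory_allocated() / gib, flush=True)
        print("===========================================")


if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # roughly 4 MiB of float32
    print_memory("after allocating one tensor")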
@@ -335,6 +343,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
                 _param.grad += _grad
+                self._fstp_handler.reduce_scatter_handlers[key] = None
 
         bucket.reset_by_rank(rank)
 
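The new line clears the handler entry once the reduce-scattered gradient has been accumulated, so the temporary gradient buffer no longer holds a live reference and can be freed before the next allocation peak. A minimal sketch of that wait / accumulate / release sequence, using a stand-in handle class and an illustrative parameter key rather than the real communication work object:

import torch


class FakeWork:
    # Stand-in for the async handle returned by a reduce-scatter call.
    def wait(self) -> None:
        pass


reduce_scatter_handlers = {"blocks.0.mlp.w1.weight": (FakeWork(), torch.ones(4))}
param = torch.nn.Parameter(torch.zeros(4))
param.grad = torch.zeros(4)

comm_handle, _grad = reduce_scatter_handlers["blocks.0.mlp.w1.weight"]
comm_handle.wait()                                         # ensure the async reduce-scatter finished
param.grad += _grad                                        # accumulate the reduced gradient shard
reduce_scatter_handlers["blocks.0.mlp.w1.weight"] = None   # drop the buffer reference so it can be freed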
@@ -358,6 +367,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
                 _param.grad += _grad
+                self._fstp_handler.reduce_scatter_handlers[key] = None
 
         # reduce grad
         if self.skip_grad_reduce is False:
@@ -565,6 +575,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
+        print_memory("No 1")
         if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
@@ -589,7 +600,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket.empty()
         self._bucket_in_progress = []
         self._param_store.clear_grads_of_previous_reduced_params()
-
+        print_memory("No 2")
         # compute norm for gradients in the last bucket
         total_norms = {}
         for group_id in range(self.num_param_groups):
@@ -611,10 +622,12 @@ class HybridZeroOptimizer(BaseOptimizer):
             scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
             dist.all_reduce(scaled_norm_tensor, group=pg)
             total_norms[group_name] = scaled_norm_tensor.item()
+        print_memory("No 3")
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
 
+        print_memory("No 4")
 
         return self._step(closure=closure, norms=total_norms)
 
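The numbered print_memory checkpoints report cumulative allocation at each point; pairing them with a peak-counter reset gives per-phase peaks instead. A sketch of that variation (profile_phase and the phase name are hypothetical, not part of the commit):

import torch


def profile_phase(name, fn, *args, **kwargs):
    # Reset the peak counter, run the phase, then read the peak it produced.
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"{name}: peak {peak_gib:.2f} GiB", flush=True)
    return out


if torch.cuda.is_available():
    profile_phase("allocate", lambda: torch.empty(2048, 2048, device="cuda"))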
@@ -661,7 +674,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, norms
-
+        print_memory("No 5")
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -702,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-
+        print_memory("No 6")
         # unscale and clip grads
         # get the global norm
         global_norm_groups = {}
@@ -725,9 +738,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
+            print_memory("No 7")
             self.optim.step()
+            print_memory("No 8")
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+            print_memory("No 9")
             # update fp16 partition updated by the current rank
             for group_id in range(len(self._fp16_param_groups)):
                 if self.param_group_has_params[group_id]:
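Checkpoints "No 8" and "No 9" bracket the release of the fp32 master gradients right after the optimizer update, which is where a large chunk of optimizer-side memory is returned. A minimal CPU-only sketch of that release step (the local release_param_grad mirrors, but is not, the imported helper):

import torch


def release_param_grad(params) -> None:
    # Drop gradient storage once the update has consumed it.
    for p in params:
        p.grad = None


fp32_params = [torch.nn.Parameter(torch.randn(8)) for _ in range(2)]
for p in fp32_params:
    p.grad = torch.randn(8)

optimizer = torch.optim.SGD(fp32_params, lr=0.1)
optimizer.step()                     # "No 7" / "No 8" bracket this update
release_param_grad(fp32_params)      # "No 9" comes after this release
assert all(p.grad is None for p in fp32_params)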
@@ -736,17 +752,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)
-
+        print_memory("No 10")
         torch.cuda.synchronize()
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()
 
         timer("step").stop()
 
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
+        print_memory("No 11")
         return True, global_norm_groups
 
     def broadcast_params(self):