mirror of https://github.com/InternLM/InternLM
Merge branch 'feat/fstp' of https://github.com/yingtongxiong/InternLM into feat/fstp
commit
a5c6e457b9
|
@ -194,6 +194,7 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
_output, _loss, _moe_loss = self._train_one_batch(
|
_output, _loss, _moe_loss = self._train_one_batch(
|
||||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||||
)
|
)
|
||||||
|
engine.optimizer.reset_reduce_bucket()
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss += _loss
|
loss += _loss
|
||||||
|
|
|
@ -329,6 +329,8 @@ class FSTPAllGatherSyncHandler:
|
||||||
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
||||||
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
||||||
|
|
||||||
|
self.reduce_scatter_handlers = {}
|
||||||
|
|
||||||
# just want to share same for loop for ModuleList and Module
|
# just want to share same for loop for ModuleList and Module
|
||||||
if not isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
model = [model]
|
model = [model]
|
||||||
|
@ -337,16 +339,22 @@ class FSTPAllGatherSyncHandler:
|
||||||
if isinstance(_chunk, NaiveAMPModel):
|
if isinstance(_chunk, NaiveAMPModel):
|
||||||
_chunk = _chunk.model
|
_chunk = _chunk.model
|
||||||
|
|
||||||
for _, children in _chunk.named_children():
|
for _chunk_name, children in _chunk.named_children():
|
||||||
if isinstance(children, nn.ModuleList):
|
if isinstance(children, nn.ModuleList):
|
||||||
for idx, block in enumerate(children):
|
for idx, block in enumerate(children):
|
||||||
index = 0
|
index = 0
|
||||||
self.block_module[idx] = {}
|
self.block_module[idx] = {}
|
||||||
for _, sub in block.named_children():
|
for _sub_name, sub in block.named_children():
|
||||||
sub_modules = list(sub.children())
|
sub_modules = list(sub.children())
|
||||||
if len(sub_modules) > 0:
|
if len(sub_modules) > 0:
|
||||||
for name, child in sub.named_children():
|
for name, child in sub.named_children():
|
||||||
if isinstance(child, FSTPLinear):
|
if isinstance(child, FSTPLinear):
|
||||||
|
|
||||||
|
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
|
||||||
|
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
|
||||||
|
if child.bias is not None:
|
||||||
|
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
|
||||||
|
|
||||||
self.FSTP_modules.append(child)
|
self.FSTP_modules.append(child)
|
||||||
self.module_block[child] = idx
|
self.module_block[child] = idx
|
||||||
self.block_module[idx][index] = child
|
self.block_module[idx][index] = child
|
||||||
|
@ -451,6 +459,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
||||||
self.head = []
|
self.head = []
|
||||||
|
|
||||||
|
self.reduce_scatter_handlers = {}
|
||||||
|
|
||||||
# just want to share same for loop for ModuleList and Module
|
# just want to share same for loop for ModuleList and Module
|
||||||
if not isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
model = [model]
|
model = [model]
|
||||||
|
@ -459,7 +469,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
if isinstance(_chunk, NaiveAMPModel):
|
if isinstance(_chunk, NaiveAMPModel):
|
||||||
_chunk = _chunk.model
|
_chunk = _chunk.model
|
||||||
|
|
||||||
for _, children in _chunk.named_children():
|
for _chunk_name, children in _chunk.named_children():
|
||||||
if isinstance(children, nn.ModuleList):
|
if isinstance(children, nn.ModuleList):
|
||||||
for idx, block in enumerate(children):
|
for idx, block in enumerate(children):
|
||||||
index = 0
|
index = 0
|
||||||
|
@ -468,7 +478,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.block_to_index[block] = idx
|
self.block_to_index[block] = idx
|
||||||
self.index_to_block[idx] = block
|
self.index_to_block[idx] = block
|
||||||
self.index_to_fsdp_modules[idx] = []
|
self.index_to_fsdp_modules[idx] = []
|
||||||
for _, sub in block.named_children():
|
for _sub_name, sub in block.named_children():
|
||||||
sub_modules = list(sub.children())
|
sub_modules = list(sub.children())
|
||||||
if len(sub_modules) > 0:
|
if len(sub_modules) > 0:
|
||||||
for name, child in sub.named_children():
|
for name, child in sub.named_children():
|
||||||
|
@ -486,6 +496,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.index_to_fsdp_modules[idx].append(child)
|
self.index_to_fsdp_modules[idx].append(child)
|
||||||
self.module_name_index[child] = index
|
self.module_name_index[child] = index
|
||||||
index = index + 1
|
index = index + 1
|
||||||
|
|
||||||
|
_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
|
||||||
|
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
|
||||||
|
if child.bias is not None:
|
||||||
|
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
elif isinstance(children, ScaleColumnParallelLinear):
|
elif isinstance(children, ScaleColumnParallelLinear):
|
||||||
|
|
|
@ -324,9 +324,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
||||||
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
|
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
|
||||||
output = F.linear(total_x, total_weight, total_bias)
|
output = F.linear(total_x, total_weight, total_bias)
|
||||||
if ctx.compute_weight_gradient:
|
if ctx.compute_weight_gradient:
|
||||||
ctx.save_for_backward(x, weight)
|
ctx.save_for_backward(x, weight, bias)
|
||||||
else:
|
else:
|
||||||
ctx.save_for_backward(weight)
|
ctx.save_for_backward(weight, bias)
|
||||||
return output if not return_residual else (output, x)
|
return output if not return_residual else (output, x)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -340,10 +340,10 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
||||||
all_gather_handler = ctx.all_gather_handler
|
all_gather_handler = ctx.all_gather_handler
|
||||||
module = ctx.module
|
module = ctx.module
|
||||||
if ctx.compute_weight_gradient:
|
if ctx.compute_weight_gradient:
|
||||||
x, weight = ctx.saved_tensors
|
x, weight, bias = ctx.saved_tensors
|
||||||
total_x = x
|
total_x = x
|
||||||
else:
|
else:
|
||||||
(weight,) = ctx.saved_tensors
|
weight, bias = ctx.saved_tensors
|
||||||
total_x = None
|
total_x = None
|
||||||
batch_shape = grad_output.shape[:-1]
|
batch_shape = grad_output.shape[:-1]
|
||||||
batch_dim = batch_shape.numel()
|
batch_dim = batch_shape.numel()
|
||||||
|
@ -368,9 +368,15 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
||||||
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
|
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
|
||||||
)
|
)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
|
grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
|
||||||
|
assert hasattr(weight, "_fstp_reduce_scatter_str")
|
||||||
|
all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
|
||||||
|
grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
|
||||||
if grad_bias is not None:
|
if grad_bias is not None:
|
||||||
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
|
grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
|
||||||
|
assert hasattr(bias, "_fstp_reduce_scatter_str")
|
||||||
|
all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
|
||||||
|
grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
|
||||||
else:
|
else:
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
grad_bias = grad_output if ctx.needs_input_grad[2] else None
|
grad_bias = grad_output if ctx.needs_input_grad[2] else None
|
||||||
|
@ -384,11 +390,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
||||||
else:
|
else:
|
||||||
grad_input = None
|
grad_input = None
|
||||||
|
|
||||||
if ctx.needs_input_grad[1]:
|
# if ctx.needs_input_grad[1]:
|
||||||
if world_size > 1:
|
# if world_size > 1:
|
||||||
handle_grad_weight.wait()
|
# handle_grad_weight.wait()
|
||||||
if grad_bias is not None:
|
# if grad_bias is not None:
|
||||||
handle_grad_bias.wait()
|
# handle_grad_bias.wait()
|
||||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
hysteresis = grad_scal_cfg.hysteresis
|
hysteresis = grad_scal_cfg.hysteresis
|
||||||
max_scale = grad_scal_cfg.max_scale
|
max_scale = grad_scal_cfg.max_scale
|
||||||
|
|
||||||
|
self._fstp_handler = gpc.config.fstp_handler
|
||||||
|
|
||||||
# Zero related args
|
# Zero related args
|
||||||
reduce_bucket_size = zero_cfg.reduce_bucket_size
|
reduce_bucket_size = zero_cfg.reduce_bucket_size
|
||||||
clip_grad_norm = zero_cfg.clip_grad_norm
|
clip_grad_norm = zero_cfg.clip_grad_norm
|
||||||
|
@ -301,8 +303,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# NOT IMPORTANT BUT GOOD TO KNOW:
|
# NOT IMPORTANT BUT GOOD TO KNOW:
|
||||||
# args here is not grad, but allow_unreacable and accumulate_grad
|
# args here is not grad, but allow_unreacable and accumulate_grad
|
||||||
def reduce_grad_hook(*args): # pylint: disable=W0613
|
def reduce_grad_hook(*args): # pylint: disable=W0613
|
||||||
if self.skip_grad_reduce is False:
|
reduction_func()
|
||||||
reduction_func()
|
|
||||||
|
|
||||||
accum_grad_obj.register_hook(reduce_grad_hook)
|
accum_grad_obj.register_hook(reduce_grad_hook)
|
||||||
|
|
||||||
|
@ -322,6 +323,20 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
group_id = getattr(param, "group_id")
|
group_id = getattr(param, "group_id")
|
||||||
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
|
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
|
||||||
|
|
||||||
|
def reset_reduce_bucket(self) -> None:
|
||||||
|
for bucket in self._bucket_store:
|
||||||
|
for rank, params in bucket._params.items():
|
||||||
|
for _param in params:
|
||||||
|
if not hasattr(_param, "_fstp_reduce_scatter_str"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = getattr(_param, "_fstp_reduce_scatter_str")
|
||||||
|
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
|
||||||
|
comm_handle.wait()
|
||||||
|
_param.grad += _grad
|
||||||
|
|
||||||
|
bucket.reset_by_rank(rank)
|
||||||
|
|
||||||
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
||||||
param_size = param.numel()
|
param_size = param.numel()
|
||||||
|
|
||||||
|
@ -332,11 +347,26 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
current_bucket = self._bucket_store[group_id]
|
current_bucket = self._bucket_store[group_id]
|
||||||
|
|
||||||
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||||
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
|
# wait reduce scatter communication
|
||||||
|
params = current_bucket.get_param(reduce_rank)
|
||||||
|
for _param in params:
|
||||||
|
if not hasattr(_param, "_fstp_reduce_scatter_str"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = getattr(_param, "_fstp_reduce_scatter_str")
|
||||||
|
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
|
||||||
|
comm_handle.wait()
|
||||||
|
_param.grad += _grad
|
||||||
|
|
||||||
|
# reduce grad
|
||||||
|
if self.skip_grad_reduce is False:
|
||||||
|
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
|
||||||
|
else:
|
||||||
|
current_bucket.reset_by_rank(reduce_rank)
|
||||||
|
|
||||||
# the param must not be reduced to ensure correctness
|
# the param must not be reduced to ensure correctness
|
||||||
is_param_reduced = self._param_store.is_param_reduced(param)
|
is_param_reduced = self._param_store.is_param_reduced(param)
|
||||||
if is_param_reduced:
|
if is_param_reduced and self.skip_grad_reduce is False:
|
||||||
msg = (
|
msg = (
|
||||||
f"Parameter of size ({param.size()}) has already been reduced, "
|
f"Parameter of size ({param.size()}) has already been reduced, "
|
||||||
+ "duplicate reduction will lead to arithmetic incorrectness"
|
+ "duplicate reduction will lead to arithmetic incorrectness"
|
||||||
|
|
|
@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
|
||||||
tgs_list.append(tgs_origin)
|
tgs_list.append(tgs_origin)
|
||||||
if batch_count == gpc.config.data.total_steps - 1:
|
if batch_count == gpc.config.data.total_steps - 1:
|
||||||
print(tgs_list, flush=True)
|
print(tgs_list, flush=True)
|
||||||
|
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||||
|
for tgs in tgs_list.copy():
|
||||||
|
if abs(tgs - avg_tgs) > 1000:
|
||||||
|
tgs_list.remove(tgs)
|
||||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||||
|
|
Loading…
Reference in New Issue