mirror of https://github.com/InternLM/InternLM

Merge branch 'feat/fstp' of https://github.com/yingtongxiong/InternLM into feat/fstp
commit a5c6e457b9

@@ -194,6 +194,7 @@ class NonPipelineScheduler(BaseScheduler):
             _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size
             )
+            engine.optimizer.reset_reduce_bucket()

             if return_loss:
                 loss += _loss

@@ -329,6 +329,8 @@ class FSTPAllGatherSyncHandler:
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name

+        self.reduce_scatter_handlers = {}
+
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]

@@ -337,16 +339,22 @@ class FSTPAllGatherSyncHandler:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model

-            for _, children in _chunk.named_children():
+            for _chunk_name, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
                         index = 0
                         self.block_module[idx] = {}
-                        for _, sub in block.named_children():
+                        for _sub_name, sub in block.named_children():
                             sub_modules = list(sub.children())
                             if len(sub_modules) > 0:
                                 for name, child in sub.named_children():
                                     if isinstance(child, FSTPLinear):
+
+                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                                        if child.bias is not None:
+                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
+
                                         self.FSTP_modules.append(child)
                                         self.module_block[child] = idx
                                         self.block_module[idx][index] = child

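The hunk above (and its CoarseGrained twin further down) tags every FSTP linear's weight and bias with a stable string key so that the backward pass can later file its asynchronous reduce-scatter handle under that key. A minimal, self-contained sketch of the tagging scheme follows; plain nn.Linear stands in for FSTPLinear, and all container names here are made up for illustration.

import torch.nn as nn

# two toy "transformer blocks", each holding one linear inside a named submodule
blocks = nn.ModuleList(
    nn.ModuleDict({"mixer": nn.ModuleDict({"out_proj": nn.Linear(8, 8)})}) for _ in range(2)
)

reduce_scatter_handlers = {}  # later filled with key -> (comm_handle, sharded_grad)

for idx, block in enumerate(blocks):
    for sub_name, sub in block.named_children():
        for name, child in sub.named_children():
            if isinstance(child, nn.Linear):  # stands in for FSTPLinear
                full_name = f"blocks.{idx}.{sub_name}.{name}"
                child.weight._fstp_reduce_scatter_str = f"{full_name}.weight"
                if child.bias is not None:
                    child.bias._fstp_reduce_scatter_str = f"{full_name}.bias"

print([p._fstp_reduce_scatter_str for p in blocks.parameters()])
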
@@ -450,7 +458,9 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
         self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
         self.head = []

+        self.reduce_scatter_handlers = {}
+
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
             model = [model]

@@ -459,7 +469,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model

-            for _, children in _chunk.named_children():
+            for _chunk_name, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
                         index = 0

@@ -468,7 +478,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                         self.block_to_index[block] = idx
                         self.index_to_block[idx] = block
                         self.index_to_fsdp_modules[idx] = []
-                        for _, sub in block.named_children():
+                        for _sub_name, sub in block.named_children():
                             sub_modules = list(sub.children())
                             if len(sub_modules) > 0:
                                 for name, child in sub.named_children():

@@ -486,6 +496,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                                         self.index_to_fsdp_modules[idx].append(child)
                                         self.module_name_index[child] = index
                                         index = index + 1
+
+                                        _full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
+                                        setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
+                                        if child.bias is not None:
+                                            setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
                             else:
                                 continue
                 elif isinstance(children, ScaleColumnParallelLinear):

@@ -324,9 +324,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
         if ctx.compute_weight_gradient:
-            ctx.save_for_backward(x, weight)
+            ctx.save_for_backward(x, weight, bias)
         else:
-            ctx.save_for_backward(weight)
+            ctx.save_for_backward(weight, bias)
         return output if not return_residual else (output, x)

     @staticmethod

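The change above additionally saves bias in forward so that backward receives the bias parameter object itself and can read the _fstp_reduce_scatter_str tag attached to it by the handlers. Below is a minimal save/restore sketch with a custom autograd.Function; it is a toy linear, not the fused FSTPFusedDenseFunc, and the tag value is invented.

import torch

class TaggedLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias):
        ctx.save_for_backward(x, weight, bias)  # bias is saved so backward can see its tag
        return x @ weight.t() + bias

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        grad_input = grad_output @ weight
        grad_weight = grad_output.t() @ x
        grad_bias = grad_output.sum(0)
        # a real implementation would file an async reduce-scatter handle under this key
        key = getattr(bias, "_fstp_reduce_scatter_str", None)
        print("backward saw tag:", key)
        return grad_input, grad_weight, grad_bias

x = torch.randn(4, 8, requires_grad=True)
w = torch.nn.Parameter(torch.randn(6, 8))
b = torch.nn.Parameter(torch.zeros(6))
b._fstp_reduce_scatter_str = "blocks.0.mixer.out_proj.bias"  # invented tag
TaggedLinearFn.apply(x, w, b).sum().backward()
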
@@ -340,10 +340,10 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         all_gather_handler = ctx.all_gather_handler
         module = ctx.module
         if ctx.compute_weight_gradient:
-            x, weight = ctx.saved_tensors
+            x, weight, bias = ctx.saved_tensors
             total_x = x
         else:
-            (weight,) = ctx.saved_tensors
+            weight, bias = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()

@@ -368,9 +368,15 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
             if world_size > 1:
-                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                assert hasattr(weight, "_fstp_reduce_scatter_str")
+                all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
+                grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
-                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    assert hasattr(bias, "_fstp_reduce_scatter_str")
+                    all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
+                    grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None

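With this hunk the backward pass no longer blocks on the weight and bias collectives. It launches reduce-scatter asynchronously, stores (handle, reduce-scattered gradient) in the handler's reduce_scatter_handlers dict under the parameter's tag, and hands autograd a zero placeholder of the sharded shape; the optimizer waits on the handle and adds the real shard in later (see the HybridZeroOptimizer hunks below). A single-process toy of that flow, where FakeWork and fake_reduce_scatter_raw merely imitate an async work handle and the repo's reduce_scatter_raw helper:

import torch

class FakeWork:
    def wait(self):
        pass  # a real torch.distributed work handle would block here

def fake_reduce_scatter_raw(tensor, world_size):
    # pretend every rank contributed `tensor`; each rank keeps 1/world_size of the rows
    return tensor.chunk(world_size, dim=0)[0] * world_size, FakeWork()

world_size = 4
reduce_scatter_handlers = {}

grad_weight = torch.ones(8, 16)  # full (unsharded) weight gradient
grad_weight_async, handle = fake_reduce_scatter_raw(grad_weight, world_size)
reduce_scatter_handlers["blocks.0.mixer.out_proj.weight"] = (handle, grad_weight_async)

# placeholder returned to autograd: sharded shape, same dtype and device, all zeros
grad_weight = torch.zeros(grad_weight.shape[0] // world_size, *grad_weight.shape[1:])

# ... later, before the optimizer consumes the gradient:
handle, shard = reduce_scatter_handlers["blocks.0.mixer.out_proj.weight"]
handle.wait()
grad_weight += shard
print(grad_weight.shape, grad_weight[0, 0].item())
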
@@ -384,11 +390,11 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         else:
             grad_input = None

-        if ctx.needs_input_grad[1]:
-            if world_size > 1:
-                handle_grad_weight.wait()
-                if grad_bias is not None:
-                    handle_grad_bias.wait()
+        # if ctx.needs_input_grad[1]:
+        #     if world_size > 1:
+        #         handle_grad_weight.wait()
+        #         if grad_bias is not None:
+        #             handle_grad_bias.wait()
         return grad_input, grad_weight, grad_bias, None, None, None, None


@@ -65,6 +65,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         hysteresis = grad_scal_cfg.hysteresis
         max_scale = grad_scal_cfg.max_scale

+        self._fstp_handler = gpc.config.fstp_handler
+
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
         clip_grad_norm = zero_cfg.clip_grad_norm

@@ -301,8 +303,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here is not grad, but allow_unreacable and accumulate_grad
                         def reduce_grad_hook(*args):  # pylint: disable=W0613
-                            if self.skip_grad_reduce is False:
-                                reduction_func()
+                            reduction_func()

                         accum_grad_obj.register_hook(reduce_grad_hook)

@@ -322,6 +323,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         group_id = getattr(param, "group_id")
         return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])

+    def reset_reduce_bucket(self) -> None:
+        for bucket in self._bucket_store:
+            for rank, params in bucket._params.items():
+                for _param in params:
+                    if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                        continue
+
+                    key = getattr(_param, "_fstp_reduce_scatter_str")
+                    comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
+                    comm_handle.wait()
+                    _param.grad += _grad
+
+                bucket.reset_by_rank(rank)
+
     def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
         param_size = param.numel()

@@ -332,11 +347,26 @@ class HybridZeroOptimizer(BaseOptimizer):
         current_bucket = self._bucket_store[group_id]

         if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
+            # wait reduce scatter communication
+            params = current_bucket.get_param(reduce_rank)
+            for _param in params:
+                if not hasattr(_param, "_fstp_reduce_scatter_str"):
+                    continue
+
+                key = getattr(_param, "_fstp_reduce_scatter_str")
+                comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
+                comm_handle.wait()
+                _param.grad += _grad
+
+            # reduce grad
+            if self.skip_grad_reduce is False:
+                self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
+            else:
+                current_bucket.reset_by_rank(reduce_rank)

         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)
-        if is_param_reduced:
+        if is_param_reduced and self.skip_grad_reduce is False:
             msg = (
                 f"Parameter of size ({param.size()}) has already been reduced, "
                 + "duplicate reduction will lead to arithmetic incorrectness"

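Taken together, the two optimizer hunks make the flush path drain any pending FSTP reduce-scatter shards (wait on the stored handle, add the shard into .grad) before either reducing the bucket or, when skip_grad_reduce is set during gradient accumulation, simply resetting it. A toy of that control flow with made-up bucket internals (the real _bucket_store and helper names differ):

class ToyBucket:
    def __init__(self):
        self._params = {0: []}
        self._sizes = {0: 0}

    def num_elements_in_bucket(self, rank):
        return self._sizes[rank]

    def add_param(self, rank, numel):
        self._params[rank].append(numel)
        self._sizes[rank] += numel

    def reset_by_rank(self, rank):
        self._params[rank] = []
        self._sizes[rank] = 0

reduce_bucket_size = 100
skip_grad_reduce = True  # e.g. mid gradient accumulation: queue grads but do not reduce
bucket = ToyBucket()

def store_and_try_reduce(param_numel, rank=0):
    if bucket.num_elements_in_bucket(rank) + param_numel > reduce_bucket_size:
        # 1) wait on pending reduce-scatter shards for params already in the bucket
        #    and fold them into .grad (see the FakeWork sketch above)
        # 2) reduce the bucket, or just drop it while only accumulating
        if not skip_grad_reduce:
            print("reduce bucket of", bucket.num_elements_in_bucket(rank), "elements")
        bucket.reset_by_rank(rank)
    bucket.add_param(rank, param_numel)

for numel in (60, 30, 40):  # the third add would overflow the bucket and triggers a flush
    store_and_try_reduce(numel)
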
@@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
            tgs_list.append(tgs_origin)
         if batch_count == gpc.config.data.total_steps - 1:
             print(tgs_list, flush=True)
+            avg_tgs = sum(tgs_list) / len(tgs_list)
+            for tgs in tgs_list.copy():
+                if abs(tgs - avg_tgs) > 1000:
+                    tgs_list.remove(tgs)
             print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)

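The metrics hunk prints a filtered average tokens-per-GPU-per-second at the final step: samples more than 1000 away from the raw mean are discarded, then the remaining values are averaged (the 1000 threshold is the hard-coded value from the diff). A standalone rerun of that logic with made-up throughput numbers:

tgs_list = [4050.0, 4120.0, 980.0, 4075.0, 4110.0]  # invented throughput samples

avg_tgs = sum(tgs_list) / len(tgs_list)
for tgs in tgs_list.copy():  # iterate over a copy while removing outliers in place
    if abs(tgs - avg_tgs) > 1000:
        tgs_list.remove(tgs)

print(f"avg_tgs: {sum(tgs_list) / len(tgs_list)}")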