diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index b58544c02..43c056b60 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -4,75 +4,37 @@ from torch.fx.node import Node
 from torch.fx.passes.split_module import split_module
 
 import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
 
 
-def all_gather_function(input_):
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    group = gpc.get_group(ParallelMode.PARALLEL_1D)
-    torch.distributed.all_gather(tensor_list, input_, group=group)
-    output = torch.cat(tensor_list, dim=-1).contiguous()
-    return output
+def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
+    """weight_split
+    shard an nn.Parameter along ``dim`` by attaching a 1D distribution spec
 
+    Args:
+        weight (torch.nn.parameter.Parameter): a torch Parameter instance
+        dim (int): the dimension to shard along
 
-def all_reduce_function(input_):
-    if gpc.get_world_size(ParallelMode.PARALLEL_1D) == 1:
-        return input_
-    torch.distributed.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-    return input_
-
-
-def weight_split(weight, dim):
-    #TODO: this function will be refactored by using ColoTensor dist_spec when a stable reshaper feature is ready to use.
-    num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-    shape = weight.shape
-    length = shape[dim] // num_partition
-    sharded_weight_list = []
-    for i in range(num_partition):
-        sharded_weight_list.append(weight.narrow(dim, i * length, length))
-    return sharded_weight_list[gpc.get_local_rank(ParallelMode.PARALLEL_1D)]
-
-
-def replace_all_uses_except_replaced(node, replace_node):
+    Returns:
+        torch.nn.parameter.Parameter: the parameter whose data is wrapped as a sharded ColoTensor
     """
-    Replace all uses of ``node`` in the Graph with the Node ``replace_node``,
-    except the user of ``node`` is ``replace_node``.
+    # TODO: this function temporarily works without materialization
+    # Append a Tensor spec to target_module.weight.shard
+    # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
+    # assert isinstance(weight, torch.nn.parameter.Parameter), \
+    #     f'The type of the input tensor should be torch.nn.parameter' \
+    #     f'Your input tensor is {type(weight)}'
 
-    Args:
+    # FIXME: a ProcessGroup is initialized for this tensor with only a TP communication group;
+    # only the TP-only case is considered for now.
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
 
-        replace_node (Node): The node to replace all uses of ``node`` with.
-
-    Returns:
-
-        The list of Nodes on which this change was made.
-    """
-    to_process = list(node.users)
-    for use_node in to_process:
-        if use_node == replace_node:
-            continue
-
-        def may_replace_node(n):
-            if n == node:
-                return replace_node
-            else:
-                return n
-
-        new_args = map_arg(use_node.args, may_replace_node)
-        new_kwargs = map_arg(use_node.kwargs, may_replace_node)
-        use_node._args = new_args
-        use_node._kwargs = new_kwargs
-        for old_use in use_node._input_nodes.keys():
-            old_use.users.pop(use_node)
-        use_node._input_nodes = {}
-        map_arg(use_node._args, lambda n: use_node._input_nodes.setdefault(n))
-        map_arg(use_node._kwargs, lambda n: use_node._input_nodes.setdefault(n))
-        for new_use in use_node._input_nodes.keys():
-            new_use.users.setdefault(use_node)
-    return to_process
+    spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    # Since the spec has been constructed, convert the tensor directly to a ColoTensor.
+    # setattr(weight, "fx_attr", spec)
+    weight.data = ColoTensor(data=weight.data, spec=spec)
+    return weight
 
 
 def column_shard_linear_pass(gm: torch.fx.GraphModule):
@@ -81,14 +43,10 @@ def column_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=0)
+                target_module.weight = weight_split(target_module.weight, dim=0)
                 if target_module.bias is not None:
                     target_module.bias.data = weight_split(target_module.bias.data, dim=0)
 
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_gather_function, args=(node,))
-                    replace_all_uses_except_replaced(node, new_node)
     gm.recompile()
     return gm
 
@@ -99,20 +57,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=-1)
-
-                # insert input sharding node before the sharded linear node
-                with mod_graph.inserting_before(node):
-                    input_node_list = list(node._input_nodes.keys())
-                    assert len(input_node_list) == 1, 'linear forward must have and only have one input tensor.'
-                    input_node = input_node_list[0]
-                    new_input_node = mod_graph.create_node('call_function', weight_split, args=(input_node, -1))
-                    replace_all_uses_except_replaced(input_node, new_input_node)
-
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_reduce_function, args=(node,))
-                    replace_all_uses_except_replaced(node, new_node)
+                target_module.weight = weight_split(target_module.weight, dim=-1)
 
     gm.recompile()
     return gm
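
For reviewers, a minimal usage sketch of the rewritten pass follows. It assumes torch.distributed has already been initialized (e.g. via a torchrun launch) and that the ColoTensor APIs imported in the patch are available in this version; the MLP module, the backend choice, and the launch details are illustrative assumptions and not part of the patch.

# Minimal sketch: drive column_shard_linear_pass over a symbolically traced module.
# Assumes a distributed launch (e.g. torchrun --nproc_per_node=N) so that
# torch.distributed.get_world_size() inside weight_split returns the TP degree.
import torch
import torch.distributed as dist
from torch.fx import symbolic_trace

from colossalai.fx.passes.shard_1d_pass import column_shard_linear_pass


class MLP(torch.nn.Module):  # illustrative module, not part of the patch
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(16, 32)
        self.linear2 = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


def main():
    dist.init_process_group(backend="nccl")  # or "gloo" for a CPU-only smoke test

    gm = symbolic_trace(MLP())
    # After this patch the pass only rewraps each nn.Linear weight as a ColoTensor
    # carrying a TP1D shard spec along dim 0; it no longer inserts communication
    # (all-gather) nodes into the traced graph.
    gm = column_shard_linear_pass(gm)

    for name, param in gm.named_parameters():
        print(name, type(param.data), param.data.shape)


if __name__ == "__main__":
    main()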