import torch
import torch.nn as nn
import operator
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

# Module types treated as element-wise: in transformer_mlp_pass they do not
# interrupt the search for a linear -> linear pair.
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]

# Functions / operators treated as element-wise for the same purpose.
ELEMENTWISE_FUNC_OP = [
    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]


def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
    """Attach a tensor-parallel shard annotation to ``weight``.

    The tensor is not split here. Instead, an ``fx_attr`` tuple
    ``(dim, "SHARD", "TP", mode)`` is set on the object in place, where
    ``mode`` is ``"col_normal"`` or ``"col_needs_many_outputs"``. Consumers
    of ``fx_attr`` are not visible in this file.

    Args:
        weight (torch.nn.parameter.Parameter): the parameter to annotate.
        dim (int): the dimension to shard along.
        col_normal (bool): column shard with gather (``"col_normal"``) when
            True, otherwise ``"col_needs_many_outputs"``.

    Returns:
        torch.nn.parameter.Parameter: the same object that was passed in,
        now carrying the ``fx_attr`` attribute.
    """
    mode = "col_normal" if col_normal else "col_needs_many_outputs"
    setattr(weight, "fx_attr", (dim, "SHARD", "TP", mode))
    return weight


def column_shard_linear_pass(gm: torch.fx.GraphModule):
    """Annotate every ``nn.Linear`` called in ``gm`` for column sharding.

    Both the weight and (if present) the bias are annotated along dim 0 via
    :func:`weight_split` with ``col_normal=False``. Currently for testing only.

    Args:
        gm (torch.fx.GraphModule): the traced module; modified in place.

    Returns:
        torch.fx.GraphModule: ``gm`` after ``recompile()``.
    """
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        target_module = node.graph.owning_module.get_submodule(node.target)
        if isinstance(target_module, torch.nn.Linear):
            target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False)
            if target_module.bias is not None:
                # Annotate the Parameter itself, not ``bias.data``: ``.data``
                # returns a new tensor object on every access, so an attribute
                # set on it would never persist on the bias parameter.
                target_module.bias = weight_split(target_module.bias, dim=0, col_normal=False)
    gm.recompile()
    return gm


def row_shard_linear_pass(gm: torch.fx.GraphModule):
    """Annotate every ``nn.Linear`` called in ``gm`` for row sharding.

    Only the weight is annotated (along dim -1, ``col_normal=False``); the
    bias is left unannotated. Currently for testing only.

    Args:
        gm (torch.fx.GraphModule): the traced module; modified in place.

    Returns:
        torch.fx.GraphModule: ``gm`` after ``recompile()``.
    """
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        target_module = node.graph.owning_module.get_submodule(node.target)
        if isinstance(target_module, torch.nn.Linear):
            target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False)
    gm.recompile()
    return gm


def _annotate_row_shard(module: nn.Linear, process_group: ProcessGroup, world_size: int):
    """Attach row-shard specs (shard weight on dim -1) to ``module.weight``; bias is untouched."""
    dist_spec = shard(dims=[-1], num_partitions=[world_size])
    comp_spec = ComputeSpec(ComputePattern.TP1D)
    setattr(module.weight, 'pg', process_group)
    setattr(module.weight, 'dist_spec', dist_spec)
    setattr(module.weight, 'comp_spec', comp_spec)


def _annotate_col_shard(module: nn.Linear, process_group: ProcessGroup, world_size: int):
    """Attach column-shard specs (shard on dim 0, output not replicated) to weight and bias."""
    weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
    weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
    weight_comp_spec.output_replicate = False
    setattr(module.weight, 'pg', process_group)
    setattr(module.weight, 'dist_spec', weight_dist_spec)
    setattr(module.weight, 'comp_spec', weight_comp_spec)

    if module.bias is not None:
        bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
        bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
        bias_comp_spec.output_replicate = False
        setattr(module.bias, 'pg', process_group)
        setattr(module.bias, 'dist_spec', bias_dist_spec)
        setattr(module.bias, 'comp_spec', bias_comp_spec)


def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
    """Annotate transformer-MLP-like ``Linear -> element-wise ops -> Linear`` pairs.

    Walks the FX graph from its first node by following ``node.users``.
    A ``Linear`` starts tracking and becomes the column-shard candidate. If
    the next ``Linear`` is reached only through element-wise ops (see
    ``ELEMENTWISE_MODULE_OP`` / ``ELEMENTWISE_FUNC_OP``) and without a
    branch, the first layer receives column-shard specs and the second
    receives row-shard specs. Specs are attached as ``pg`` / ``dist_spec`` /
    ``comp_spec`` attributes on the parameters. Tracking resets on any
    non-element-wise op and on any node with more than one user.

    Args:
        graph_module (torch.fx.GraphModule): traced module; its parameters
            are annotated in place (no recompile is performed).
        process_group (ProcessGroup): group whose ``world_size()`` sets the
            number of partitions; also stored on each annotated parameter.

    Returns:
        torch.fx.GraphModule: the same ``graph_module``.
    """
    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
    graph = graph_module.graph
    world_size = process_group.world_size()

    def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
        # Depth-first walk over users looking for consecutive linear layers.
        # NOTE(review): recursion depth grows with graph depth, and a node
        # reachable through several paths is visited once per path.
        is_linear_module = False

        if node.op == 'call_module':
            module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(module, nn.Linear):
                is_linear_module = True
                if start_tracking:
                    # A first linear is already recorded, so this one is the
                    # second of the pair and gets row-sharded.
                    annotation_record['row'] = module
                    # Insertion order puts 'col' before 'row'.
                    for shard_type, recorded in annotation_record.items():
                        if shard_type == 'row':
                            _annotate_row_shard(recorded, process_group, world_size)
                        elif shard_type == 'col':
                            _annotate_col_shard(recorded, process_group, world_size)
                    start_tracking = False
                    annotation_record.clear()
                else:
                    # First linear of a potential pair: column-shard candidate.
                    start_tracking = True
                    annotation_record['col'] = module

        if start_tracking and not is_linear_module:
            # Only white-listed element-wise ops may sit between the pair;
            # anything else resets tracking.
            if node.op == 'call_module':
                module = node.graph.owning_module.get_submodule(node.target)
                if module.__class__ not in ELEMENTWISE_MODULE_OP:
                    start_tracking = False
            elif node.op in ('call_function', 'call_method'):
                if node.target not in ELEMENTWISE_FUNC_OP:
                    start_tracking = False
            elif len(node.users) > 1:
                start_tracking = False

            if not start_tracking:
                annotation_record.clear()

        # Stop tracking when the graph branches, e.g.
        #   out1 = self.linear1(x)
        #   out2 = self.linear2(x)
        #   return out1 + out2
        next_nodes = list(node.users)
        if len(next_nodes) > 1:
            start_tracking = False
            annotation_record.clear()

        for next_node in next_nodes:
            _traverse_and_annotate(next_node, start_tracking, annotation_record, world_size)

    # NOTE(review): traversal starts only from the first node in the graph;
    # linear layers reachable only from other inputs are never visited.
    placeholder_node = list(graph.nodes)[0]
    annotate_record = {}
    _traverse_and_annotate(placeholder_node, False, annotate_record, world_size)

    return graph_module