diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index b1ec540d6..4b676d153 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -2,6 +2,7 @@ from .batch_norm_handler import BatchNormModuleHandler
 from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .getatrr_handler import GetattrHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py
new file mode 100644
index 000000000..53addb873
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py
@@ -0,0 +1,34 @@
+from typing import Dict, List
+
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .strategy import GetattrGenerator, StrategyGenerator
+
+__all__ = ['GetattrHandler']
+
+
+class GetattrHandler(NodeHandler):
+    """
+    A GetattrHandler deals with the sharding strategies for a get_attr node.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use transposed shape for strategies
+        # the strategies will be transformed back to its original shape in self.post_process
+
+        # There are only two possible types for a get_attr node:
+        # 1. torch.Tensor (a torch.nn.Parameter or a buffer)
+        # 2. torch.nn.Module
+        # For now, only the first case is supported by the tracer, so we don't need to worry
+        # about issues related to the type of node._meta_data.
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"output": physical_output}
+
+        return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 8d9683766..f576b4e4b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -6,6 +6,7 @@ from torch.fx.node import Node
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
+    OperationDataType,
     ShardingStrategy,
     StrategiesVector,
     TrainCycleItem,
@@ -49,6 +50,9 @@ class NodeHandler(ABC):
 
         for node in self.predecessor_node:
             node_name = str(node)
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]
 
             # get the sharding specs for this node generated
             # in its own node handler
@@ -59,10 +63,6 @@
                 prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
             ]
 
-            # get the current sharding spec generated by this node handler
-            op_data = strategy.get_op_data_by_name(node_name)
-            current_sharding_spec = strategy.sharding_specs[op_data]
-
             # create data structrure to store costs
             if op_data not in resharding_costs:
                 resharding_costs[node] = []
@@ -71,11 +71,14 @@
             # compute the resharding cost to switch to the sharding spec generated
             # by the current node handler
             for prev_sharding_spec in prev_sharding_specs:
-                _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
-                                                                                    current_sharding_spec)
-                resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
-                                                 bwd=resharding_cost["backward"],
-                                                 total=resharding_cost["total"])
+                if op_data.type == OperationDataType.PARAM:
+                    resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+                else:
+                    _, _, resharding_cost = shape_consistency_manager.shape_consistency(
+                        prev_sharding_spec, current_sharding_spec)
+                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
+                                                     bwd=resharding_cost["backward"],
+                                                     total=resharding_cost["total"])
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
index 402485352..3c4c05786 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
@@ -13,6 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.reshape)
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.permute)
+@operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
 class ReshapeHandler(NodeHandler):
     """
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index 28ee05c0e..954370793 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -1,6 +1,7 @@
 from .batch_norm_generator import BatchNormStrategyGenerator
 from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
 from .conv_strategy_generator import ConvStrategyGenerator
+from .getattr_generator import GetattrGenerator
 from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
 from .matmul_strategy_generator import (
@@ -22,5 +23,5 @@ __all__ = [
     'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
-    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator'
+    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
new file mode 100644
index 000000000..753ab1726
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
@@ -0,0 +1,53 @@
+from typing import List
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+from .strategy_generator import StrategyGenerator
+
+__all__ = ['GetattrGenerator']
+
+
+class GetattrGenerator(StrategyGenerator):
+    """
+    GetattrGenerator is a generic class to generate strategies for a get_attr node.
+    """
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy):
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+
+        # compute fwd cost incurred
+        # fwd_cost = output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+
+        bwd_mem_cost = MemoryCost(activation=0, parameter=0)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        dim_partition_dict_mapping = {
+            "output": {},
+        }
+        communication_action_mapping = {}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = 'Replica Attribute'
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+
+        return [strategy]
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 57d5dfa79..48035e6b8 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -6,9 +6,10 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node
 
-from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
+from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
+from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -71,25 +72,8 @@ class StrategiesConstructor:
 
             # get_attr node
             if node.op == 'get_attr':
-                # Same as placeholder nodes, if solver_options.fast is True, we just let them in
-                # fully replicate status, then strategies of following node will be treated equally due
-                # to replicate status has no resharding cost to other status. At the same time, the searching
-                # space is smaller than enumerating all the possible sharding spec for the get_attr node.
-                # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
-                if self.solver_options.fast:
-                    # create sharding strategy for get_attr
-                    name = 'Replica Attribute'
-                    dim_partition_dict = {}
-                    output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-                    # TODO: use meta_info_prop to profile memory cost
-                    memory_cost = 0
-                    sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
-                    strategies_vector.append(sharding_strategy_attribute)
-
-            # # get_attr node
-            # elif node.op == 'get_attr':
-            #     # TODO: implement getattr node handler
-            #     pass
+                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler.register_strategy()
 
             # call_module node
             elif node.op == 'call_module':
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index e6d7be820..fb8f46b5e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -20,6 +20,7 @@ class BiasAdditionConv(BiasAdditionModule):
             if hasattr(conv_module, attr_name):
                 non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
         if conv_module.padding_mode != "zeros":
+            # TODO: non-zero padding modes require some extra processing of the input
             conv_type = type(conv_module)
             if conv_type == "torch.nn.Conv1d":
                 padding_element = _single(0)
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index ca1ded09c..6295523b8 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -93,17 +93,18 @@ class ColoTracer(Tracer):
         origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
         # dispatch the arguments generator depending on the kind and target in origin arguments.
         args_metas, _ = extract_meta(*args, **kwargs)
+        handle = None
         if kind == "call_function":
             if bias_addition_function.has(target):
-                return bias_addition_function.get(target)(self, target, args, kwargs)
+                handle = bias_addition_function.get(target)(self, target, args, kwargs)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
-                return bias_addition_function.get(target.__name__)(self, target, args, kwargs)
+                handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_function.has(method):
-                return bias_addition_function.get(method)(self, target, args, kwargs)
+                handle = bias_addition_function.get(method)(self, target, args, kwargs)
 
         elif kind == "call_module":
             if not hasattr(self, "orig_forward"):
@@ -115,10 +116,12 @@
             if bias_addition_module.has(mod_type) and mod.bias is not None:
                 function_to_substitute = module_to_func_dict[mod_type]
                 handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
-                return handle.generate()
         finally:
             self._disable_module_getattr = False
 
+        if handle is not None:
+            return handle.generate()
+
         # create nodes using patched arguments
         proxy = super().create_proxy(*origin_arguments)
         proxy: ColoProxy
@@ -254,7 +257,9 @@
             atoms = target.split(".")
             for atom in atoms:
                 attr_itr = getattr(attr_itr, atom)
-            if isinstance(attr_itr, torch.Tensor):
+            if isinstance(attr_itr, torch.nn.parameter.Parameter):
+                meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
+            elif isinstance(attr_itr, torch.Tensor):
                 meta_out = attr_itr.to(device="meta")
             else:
                 meta_out = attr_itr
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
new file mode 100644
index 000000000..ad093c2ed
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+
+
+class GetattrModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False)
+
+    def forward(self, input):
+        weight = self.conv.weight
+        return weight
+
+
+def test_getattr_handler():
+    model = GetattrModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=0] = placeholder[target=input]
+    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
+    #     return conv_weight
+    graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')})
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    getattr_node = list(graph.nodes)[1]
+    getattr_strategies_vector = StrategiesVector(getattr_node)
+
+    # build handler
+    getattr_handler = GetattrHandler(node=getattr_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=getattr_strategies_vector)
+
+    getattr_handler.register_strategy(compute_resharding_cost=False)
+    # check operation data mapping
+    mapping = getattr_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['output'].name == "conv_weight"
+    assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3))
+    assert mapping['output'].type == OperationDataType.OUTPUT
+    strategy_name_list = [val.name for val in getattr_handler.strategies_vector]
+    assert "Replica Attribute" in strategy_name_list
+
+
+if __name__ == '__main__':
+    test_getattr_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
new file mode 100644
index 000000000..b67641f61
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
@@ -0,0 +1,128 @@
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+
+
+def _param_resharding_cost_assertion(node):
+    for strategy in node.strategies_vector:
+        for prev_node, resharding_cost in strategy.resharding_costs.items():
+            if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM:
+                for cost in resharding_cost:
+                    assert cost.fwd == 0
+                    assert cost.bwd == 0
+                    assert cost.total == 0
+
+
+class LinearModel(torch.nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = x * 2
+
+        return x
+
+
+class ConvModel(torch.nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels=in_channels,
+                                    out_channels=out_channels,
+                                    kernel_size=kernel_size,
+                                    bias=bias)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = x * 2
+
+        return x
+
+
+def test_linear_module():
+    model = LinearModel(4, 8)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    tracer = ColoTracer()
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
+    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
+    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
+    #     return mul
+    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
+    # def forward(self, x : torch.Tensor):
+    #     linear_weight = self.linear.weight
+    #     linear_bias = self.linear.bias
+    #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
+    #     add = linear + linear_bias; linear = linear_bias = None
+    #     mul = add * 2; add = None
+    #     return mul
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+    node_list = list(graph.nodes)
+
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+    linear_node = node_list[3]
+    _param_resharding_cost_assertion(linear_node)
+
+
+def test_conv_module():
+    model = ConvModel(3, 6, 2)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    tracer = ColoTracer()
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
+    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
+    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
+    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
+    #     return mul
+    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
+    # def forward(self, x : torch.Tensor):
+    #     conv_weight = self.conv.weight
+    #     conv_bias = self.conv.bias
+    #     conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
+    #     view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
+    #     add = conv2d + view; conv2d = view = None
+    #     mul = add * 2; add = None
+    #     return mul
+    gm = ColoGraphModule(model, graph)
+
+    gm.recompile()
+    node_list = list(graph.nodes)
+    conv_node = node_list[3]
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+    _param_resharding_cost_assertion(conv_node)
+
+
+if __name__ == '__main__':
+    test_linear_module()
+    test_conv_module()
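
Usage note: the sketch below is a condensed, illustrative example distilled from the test added in this diff (test_getattr_handler.py). The imports, constructors, and the register_strategy(compute_resharding_cost=False) call are taken from the diff above; the TinyModel class and the final printed check are hypothetical additions, not part of the patch.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoTracer


class TinyModel(nn.Module):
    # hypothetical toy module: returning self.conv.weight from forward
    # produces a get_attr node in the traced graph

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False)

    def forward(self, x):
        return self.conv.weight


if __name__ == '__main__':
    model = TinyModel()
    # trace on meta tensors so no real memory is allocated
    graph = ColoTracer().trace(model, meta_args={'x': torch.rand(4, 4, 64, 64).to('meta')})

    # a 2 x 2 logical mesh over 4 physical devices, as in the tests above
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

    # the second node of the traced graph is the get_attr node for conv.weight
    getattr_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(getattr_node)

    # the new handler registers a single fully-replicated 'Replica Attribute' strategy
    handler = GetattrHandler(node=getattr_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
    handler.register_strategy(compute_resharding_cost=False)
    print([strategy.name for strategy in strategies_vector])    # expected: ['Replica Attribute']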