from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def _build_device_mesh():
    """Return the 2x2 DeviceMesh over physical device ids 0-3 shared by both tests."""
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    return DeviceMesh(physical_mesh_id, mesh_shape)


def _assert_all_operands_populated(mapping):
    """Assert every OperationData in ``mapping`` has both ``data`` and ``logical_shape`` set.

    Args:
        mapping: the dict returned by a handler's ``get_operation_data_mapping()``,
            whose values are ``OperationData`` instances.
    """
    for op_data in mapping.values():
        assert isinstance(op_data, OperationData)
        assert op_data.logical_shape is not None
        assert op_data.data is not None


def test_linear_module_handler():
    """Check the operand mapping LinearModuleHandler builds for an ``nn.Linear`` submodule call."""
    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    # Wrap the traced graph in a module; the result itself is not inspected here.
    ColoGraphModule(model, graph)
    device_mesh = _build_device_mesh()

    # Node 0 is the ``input`` placeholder; node 1 is expected to be the call to
    # submodule ``0`` (its output node is named "_0", asserted below).
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()
    _assert_all_operands_populated(mapping)

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    # The physical weight is [out, in]; the handler exposes it transposed ([in, out]) as its logical shape.
    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


def test_linear_function_handler():
    """Check the operand mapping LinearFunctionHandler builds for a traced functional linear call."""
    model = nn.Linear(10, 20).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    # Wrap the traced graph in a module; the result itself is not inspected here.
    ColoGraphModule(model, graph)
    device_mesh = _build_device_mesh()

    # Expected node order: input placeholder, weight, bias, linear, output -> index 3 is
    # the linear call (its output node is named "linear", asserted below).
    linear_func_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node,
                                    device_mesh=device_mesh,
                                    strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()
    _assert_all_operands_populated(mapping)

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    # Weight and bias reach the function as graph arguments, hence type ARG rather than PARAM.
    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


if __name__ == '__main__':
    test_linear_module_handler()
    test_linear_function_handler()