import torch
import torch.nn as nn

from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
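
# Both tests trace the model with tensors on the 'meta' device: meta tensors
# carry only shape and dtype metadata, so no real storage is allocated during
# tracing. A minimal illustration (plain PyTorch, nothing ColossalAI-specific):
#
#   x = torch.rand(4, 10).to('meta')
#   assert x.is_meta                         # metadata only, no storage
#   assert x.shape == torch.Size([4, 10])
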
def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    # node 1 is expected to be the call_module node for the linear layer
    # (the placeholder for the input comes first)
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    # the weight is stored as [out_features, in_features]; its logical shape
    # is the transpose
    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT

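
# The 'other' (weight) assertions above rest on a layout convention: nn.Linear
# stores its parameter physically as [out_features, in_features], while the
# handler reports the transposed logical shape so the matmul reads as
# [4, 10] x [10, 20] -> [4, 20]. A plain-PyTorch sketch of that relationship
# (no ColossalAI APIs assumed):
#
#   weight = torch.empty(20, 10, device='meta')       # physical: [out, in]
#   assert weight.t().shape == torch.Size([10, 20])   # logical: [in, out]
#   out = torch.rand(4, 10, device='meta') @ weight.t()
#   assert out.shape == torch.Size([4, 20])
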
def test_linear_function_handler():
    model = nn.Linear(10, 20).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    # node 3 is expected to be the call_function node for F.linear
    # (after the input placeholder and the weight/bias accessor nodes)
    linear_func_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT

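
# Indexing graph.nodes by position (as above) is brittle if the tracer ever
# inserts extra nodes. A sketch of a position-independent lookup, assuming
# only the standard torch.fx node attributes `op` and `target`:
#
#   import torch.nn.functional as F
#   linear_func_node = next(
#       node for node in graph.nodes
#       if node.op == 'call_function' and node.target == F.linear
#   )
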
if __name__ == '__main__':
    test_linear_module_handler()
    test_linear_function_handler()