[autoparallel] Add metainfo support for F.linear (#1987)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator
pull/2005/head
Boyuan Yao 2 years ago committed by GitHub
parent 2edbef13cc
commit 6cd784ffee

@@ -19,10 +19,13 @@ from ..registry import meta_register

 __all__ = ['linear_meta_info']


+@meta_register.register(torch.nn.functional.linear)
 @meta_register.register(torch.nn.Linear)
 def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.Linear meta info generator
-    The atens graph of torch.nn.Linear with bias is
+    """torch.nn.Linear & torch.nn.functional.linear meta info generator
+    NOTE: currently we separate the bias part from the biased linear ops; the memory consumption of the bias
+    addition will be considered in the add metainfo generator, but the bias mechanism is kept in the linear
+    metainfo generator for future use.

     graph():
     %input_2 : [#users=2] = placeholder[target=placeholder](default=)
     %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:

     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]

     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
@@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         output_tensor: torch.Tensor
         output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+        if len(weight_tensors[0].shape) == 2:
+            weight_tensor, bias_tensor = weight_tensors
+        else:
+            bias_tensor, weight_tensor = weight_tensors
+    else:
+        weight_tensor = weight_tensors[0]

     if has_bias:
         # calculate cost with bias
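
Since F.linear hands both weight and bias to the generator as PARAM operands with no guaranteed order, the new code tells them apart by rank: the 2-D tensor is the weight and the 1-D tensor is the bias. Below is a minimal, self-contained sketch of that disambiguation (not the library code; split_weight_and_bias is a hypothetical helper used only for illustration):

import torch

def split_weight_and_bias(param_tensors):
    # The 2-D parameter is the weight; the 1-D parameter, if present, is the bias.
    if len(param_tensors) > 1:
        if len(param_tensors[0].shape) == 2:
            weight, bias = param_tensors
        else:
            bias, weight = param_tensors
        return weight, bias
    return param_tensors[0], None

# Toy check with meta tensors standing in for the PARAM operands of F.linear(input, weight, bias).
weight = torch.empty(128, 64, device='meta')
bias = torch.empty(128, device='meta')
w, b = split_weight_and_bias([bias, weight])
print(w.shape, b.shape)    # torch.Size([128, 64]) torch.Size([128])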

@@ -92,8 +92,12 @@ class MetaInfo:

         Compute meta info based on sharding strategy and the given target function.
         """
-        assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
-        meta_func = meta_register.get(self._target.__class__)
+        try:
+            # module
+            meta_func = meta_register.get(self._target.__class__)
+        except:
+            # function
+            meta_func = meta_register.get(self._target)

         # construct args for meta_func
         args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
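
The try/except replaces the class-only assert so that one registry can serve both kinds of targets: nn.Linear instances are looked up by their class, while the plain function F.linear is looked up by the function object itself. A toy sketch of that dispatch pattern follows; the Registry class and its register/get methods are illustrative stand-ins, not the actual meta_register API.

import torch
import torch.nn.functional as F

class Registry:
    # Hypothetical stand-in for the meta registry, keyed by class or function.
    def __init__(self):
        self._table = {}

    def register(self, key):
        def wrapper(func):
            self._table[key] = func
            return func
        return wrapper

    def get(self, key):
        return self._table[key]    # raises KeyError for unregistered targets

registry = Registry()

@registry.register(F.linear)
@registry.register(torch.nn.Linear)
def linear_meta(*args, **kwargs):
    return "linear meta info"

def lookup(target):
    # Module targets are registered by their class, function targets by themselves.
    try:
        return registry.get(target.__class__)
    except KeyError:
        return registry.get(target)

print(lookup(torch.nn.Linear(4, 4))())    # resolved via the module class
print(lookup(F.linear)())                 # resolved via the function object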

@@ -20,7 +20,17 @@ if torch.__version__ >= '1.12.0':
     from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register


-def _linear_module_mem_test(rank, bias, world_size, port):
+class MyModule(nn.Module):
+
+    def __init__(self, in_features=64, out_features=128):
+        super().__init__()
+        self.fc_weight = nn.Parameter(torch.randn(out_features, in_features))
+
+    def forward(self, input):
+        return nn.functional.linear(input, self.fc_weight)
+
+
+def _linear_module_mem_test(rank, world_size, port):
     """This function is for linear memory test

     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -32,7 +42,7 @@ def _linear_module_mem_test(rank, bias, world_size, port):
     """
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = nn.Sequential(nn.Linear(64, 128, bias=bias)).cuda()
+    model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda()
     input = torch.rand(8, 8, 16, 64).cuda()
     input.requires_grad = True
     physical_mesh_id = torch.arange(0, 4)
@@ -52,11 +62,50 @@ def _linear_module_mem_test(rank, bias, world_size, port):
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_linear_meta_concrete_info_match(bias=False):
+def test_linear_module_meta_concrete_info_match():
+    world_size = 4
+    run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func_module, nprocs=world_size)
+
+
+def _linear_function_mem_test(rank, world_size, port):
+    """This function is for linear memory test
+
+    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
+
+    Args:
+        rank: device rank
+        world_size: number of devices
+        port: port for initializing process group
+    """
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = MyModule().cuda()
+    input = torch.rand(8, 8, 16, 64).cuda()
+    input.requires_grad = True
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    # memory test
+    mem_test_for_node_strategy(rank=rank,
+                               model=model,
+                               device_mesh=device_mesh,
+                               node_index=2,
+                               strategy_number=13,
+                               input_args=[input],
+                               meta_arg_names=["input"])
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_linear_function_meta_concrete_info_match():
     world_size = 4
-    run_func_module = partial(_linear_module_mem_test, bias=bias, world_size=world_size, port=free_port())
+    run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
     mp.spawn(run_func_module, nprocs=world_size)


 if __name__ == '__main__':
-    test_linear_meta_concrete_info_match()
+    # test_linear_module_meta_concrete_info_match()
+    test_linear_function_meta_concrete_info_match()

@@ -7,6 +7,7 @@ from torch.fx import GraphModule

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
 from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
@@ -49,8 +50,9 @@ def mem_test_for_node_strategy(rank: int,

     # construct the strategy for the output node
     placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0]
+
     output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys()
-                      if key in placeholder_strategy.sharding_specs)
+                      if key.type == OperationDataType.OUTPUT)
     placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[
         output_key]
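
The rewritten filter picks the output operand by its type instead of testing key membership in the placeholder strategy's specs. Here is a small stand-alone sketch of that selection logic, using hypothetical stand-ins for OperationData and its type enum rather than the real colossalai classes:

from dataclasses import dataclass
from enum import Enum

class OperationDataType(Enum):    # stand-in for the colossalai enum
    ARG = 0
    PARAM = 1
    OUTPUT = 2

@dataclass(frozen=True)
class OperationData:              # stand-in; the real keys also carry shapes and sharding specs
    name: str
    type: OperationDataType

sharding_specs = {
    OperationData('input', OperationDataType.ARG): 'S0R',
    OperationData('weight', OperationDataType.PARAM): 'RS1',
    OperationData('output', OperationDataType.OUTPUT): 'S0S1',
}

# Pick the output operand directly by its type, independent of which keys the
# placeholder strategy happens to contain.
output_key = next(key for key in sharding_specs.keys() if key.type == OperationDataType.OUTPUT)
print(output_key.name)    # output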
@@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int,
                     )

     # estimated memory
-    metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
-                        target_node.graph.owning_module.get_submodule(target_node.target))
+    if target_node.op == "call_module":
+        metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
+                            target_node.graph.owning_module.get_submodule(target_node.target))
+    else:
+        metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target)
+
     print("estimated memory:")
     print(
         f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"
