Browse Source

[autoparallel] Add metainfo support for F.linear (#1987)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator
pull/2005/head
Boyuan Yao 2 years ago committed by GitHub
parent
commit
6cd784ffee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 18
      colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
  2. 8
      colossalai/auto_parallel/meta_profiler/metainfo.py
  3. 59
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
  4. 12
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py

18
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py

@ -19,10 +19,13 @@ from ..registry import meta_register
__all__ = ['linear_meta_info']
@meta_register.register(torch.nn.functional.linear)
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Linear meta info generator
The atens graph of torch.nn.Linear with bias is
"""torch.nn.Linear & torch.nn.functional.linear meta info generator
NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator,
but we will hold the bias mechanism in the linear metainfo generator for future use.
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
has_bias: bool = False
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
# process the dimension of input and output
if len(input_tensor.shape) > 2:
@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
output_tensor: torch.Tensor
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
if len(args) == 4:
bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
if len(weight_tensors) > 1:
has_bias = True
if len(weight_tensors[0].shape) == 2:
weight_tensor, bias_tensor = weight_tensors
else:
bias_tensor, weight_tensor = weight_tensors
else:
weight_tensor = weight_tensors[0]
if has_bias:
# calculate cost with bias

8
colossalai/auto_parallel/meta_profiler/metainfo.py

@ -92,8 +92,12 @@ class MetaInfo:
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
meta_func = meta_register.get(self._target.__class__)
try:
# module
meta_func = meta_register.get(self._target.__class__)
except:
# function
meta_func = meta_register.get(self._target)
# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]

59
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py

@ -20,7 +20,17 @@ if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
def _linear_module_mem_test(rank, bias, world_size, port):
class MyModule(nn.Module):
def __init__(self, in_features=64, out_features=128):
super().__init__()
self.fc_weight = nn.Parameter(torch.randn(out_features, in_features))
def forward(self, input):
return nn.functional.linear(input, self.fc_weight)
def _linear_module_mem_test(rank, world_size, port):
"""This function is for linear memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@ -32,7 +42,7 @@ def _linear_module_mem_test(rank, bias, world_size, port):
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(64, 128, bias=bias)).cuda()
model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True
physical_mesh_id = torch.arange(0, 4)
@ -52,11 +62,50 @@ def _linear_module_mem_test(rank, bias, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_meta_concrete_info_match(bias=False):
def test_linear_module_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
def _linear_function_mem_test(rank, world_size, port):
"""This function is for linear memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
Args:
rank: device rank
bias: indicate whether linear module need bias
world_size: number of devices
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = MyModule().cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# memory test
mem_test_for_node_strategy(rank=rank,
model=model,
device_mesh=device_mesh,
node_index=2,
strategy_number=13,
input_args=[input],
meta_arg_names=["input"])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_function_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_module_mem_test, bias=bias, world_size=world_size, port=free_port())
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__':
test_linear_meta_concrete_info_match()
# test_linear_module_meta_concrete_info_match()
test_linear_function_meta_concrete_info_match()

12
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py

@ -7,6 +7,7 @@ from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
@ -49,8 +50,9 @@ def mem_test_for_node_strategy(rank: int,
# construct the strategy for the output node
placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0]
output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys()
if key in placeholder_strategy.sharding_specs)
if key.type == OperationDataType.OUTPUT)
placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[
output_key]
@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int,
)
# estimated memory
metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
target_node.graph.owning_module.get_submodule(target_node.target))
if target_node.op == "call_module":
metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
target_node.graph.owning_module.get_submodule(target_node.target))
else:
metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target)
print("estimated memory:")
print(
f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"

Loading…
Cancel
Save