mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Add metainfo support for F.linear (#1987)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
* [fx] add conv metainfo class
* [fx] restore profiler
* [fx] restore meta profiler
* [autoparallel] modify unit test
* [fx] modify unit test
* [autoparallel] add batchnorm metainfo class
* [autoparallel] fix batchnorm unit test function declaration
* [fx] restore profiler
* [fx] add relu metainfo class
* [fx] restore profiler
* [autoparallel] modify metainfo input
* [autoparallel] add pooling metainfo
* [autoparallel] add F.linear metainfo generator
parent 2edbef13cc
commit 6cd784ffee
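In effect, the change lets the meta registry resolve torch.nn.functional.linear in addition to torch.nn.Linear. A minimal sketch of what this enables, assuming meta_register.has() accepts the same keys that were passed to register() (the import path is the one used in the test diff below):

import torch
from colossalai.auto_parallel.meta_profiler import meta_register

# both the module class and the functional op are now registered keys
assert meta_register.has(torch.nn.Linear)
assert meta_register.has(torch.nn.functional.linear)  # new with this commit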
@@ -19,10 +19,13 @@ from ..registry import meta_register

 __all__ = ['linear_meta_info']


+@meta_register.register(torch.nn.functional.linear)
 @meta_register.register(torch.nn.Linear)
 def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.Linear meta info generator
-
-    The atens graph of torch.nn.Linear with bias is
+    """torch.nn.Linear & torch.nn.functional.linear meta info generator
+    NOTE: currently we separate the bias part from the biased linear ops, we will consider the memory consumption in add metainfo generator,
+    but we will hold the bias mechanism in the linear metainfo generator for future use.
+
     graph():
     %input_2 : [#users=2] = placeholder[target=placeholder](default=)
     %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
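The aten-level graph quoted in the docstring (a biased linear lowering to aten.addmm) can be reproduced outside the profiler. A small sketch, assuming torch >= 1.12 so that make_fx is available:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def biased_linear(x, w, b):
    return torch.nn.functional.linear(x, w, b)

# trace down to aten ops; for a 2-D input the graph contains aten.t / aten.addmm nodes
gm = make_fx(biased_linear)(torch.randn(4, 8), torch.randn(16, 8), torch.randn(16))
gm.graph.print_tabular()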
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]

     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
@@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     output_tensor: torch.Tensor
     output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+        if len(weight_tensors[0].shape) == 2:
+            weight_tensor, bias_tensor = weight_tensors
+        else:
+            bias_tensor, weight_tensor = weight_tensors
+    else:
+        weight_tensor = weight_tensors[0]

     if has_bias:
         # calculate cost with bias
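Because the PARAM tensors are no longer looked up by name, weight and bias have to be told apart by rank (the weight is 2-D, the bias 1-D). A standalone sketch of that disambiguation with made-up shapes:

import torch

def split_params(param_tensors):
    # separate weight and bias from an unordered list of parameter tensors
    if len(param_tensors) > 1:
        # two PARAM entries: their order is not guaranteed, so use the rank to decide
        if param_tensors[0].dim() == 2:
            weight, bias = param_tensors
        else:
            bias, weight = param_tensors
        return weight, bias
    return param_tensors[0], None

weight, bias = split_params([torch.randn(128, 64), torch.randn(128)])
assert weight.dim() == 2 and bias.dim() == 1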
@@ -92,8 +92,12 @@ class MetaInfo:
         Compute meta info based on sharding strategy and the given target function.
         """

-        assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
-        meta_func = meta_register.get(self._target.__class__)
+        try:
+            # module
+            meta_func = meta_register.get(self._target.__class__)
+        except:
+            # function
+            meta_func = meta_register.get(self._target)

         # construct args for meta_func
         args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
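The try/except mirrors the two ways a target can be registered: module targets are keyed by their class, while functional targets such as F.linear are keyed by the callable itself. A toy stand-in registry (names hypothetical) showing the same lookup order:

import torch

toy_registry = {
    torch.nn.Linear: 'linear_meta_info',
    torch.nn.functional.linear: 'linear_meta_info',
}

def resolve(target):
    try:
        # module target: an nn.Module instance, keyed by its class
        return toy_registry[target.__class__]
    except KeyError:
        # function target: the callable itself is the key
        return toy_registry[target]

assert resolve(torch.nn.Linear(4, 4)) == 'linear_meta_info'
assert resolve(torch.nn.functional.linear) == 'linear_meta_info'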
@@ -20,7 +20,17 @@ if torch.__version__ >= '1.12.0':
     from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register


-def _linear_module_mem_test(rank, bias, world_size, port):
+class MyModule(nn.Module):
+
+    def __init__(self, in_features=64, out_features=128):
+        super().__init__()
+        self.fc_weight = nn.Parameter(torch.randn(out_features, in_features))
+
+    def forward(self, input):
+        return nn.functional.linear(input, self.fc_weight)
+
+
+def _linear_module_mem_test(rank, world_size, port):
     """This function is for linear memory test
     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -32,7 +42,7 @@ def _linear_module_mem_test(rank, bias, world_size, port):
     """
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = nn.Sequential(nn.Linear(64, 128, bias=bias)).cuda()
+    model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda()
     input = torch.rand(8, 8, 16, 64).cuda()
     input.requires_grad = True
     physical_mesh_id = torch.arange(0, 4)
@@ -52,11 +62,50 @@ def _linear_module_mem_test(rank, bias, world_size, port):
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_linear_meta_concrete_info_match(bias=False):
+def test_linear_module_meta_concrete_info_match():
     world_size = 4
-    run_func_module = partial(_linear_module_mem_test, bias=bias, world_size=world_size, port=free_port())
+    run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
     mp.spawn(run_func_module, nprocs=world_size)


+def _linear_function_mem_test(rank, world_size, port):
+    """This function is for linear memory test
+    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
+
+    Args:
+        rank: device rank
+        bias: indicate whether linear module need bias
+        world_size: number of devices
+        port: port for initializing process group
+    """
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = MyModule().cuda()
+    input = torch.rand(8, 8, 16, 64).cuda()
+    input.requires_grad = True
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    # memory test
+    mem_test_for_node_strategy(rank=rank,
+                               model=model,
+                               device_mesh=device_mesh,
+                               node_index=2,
+                               strategy_number=13,
+                               input_args=[input],
+                               meta_arg_names=["input"])
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_linear_function_meta_concrete_info_match():
+    world_size = 4
+    run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func_module, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_linear_meta_concrete_info_match()
+    # test_linear_module_meta_concrete_info_match()
+    test_linear_function_meta_concrete_info_match()
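Both tests are gated behind the AUTO_PARALLEL environment flag and spawn a 4-process NCCL group, so they need 4 GPUs. The node_index=2 argument presumably points at the F.linear node of the traced graph; a rough illustration with plain torch.fx (the real test uses ColoTracer with meta args):

import torch
from torch import nn
from torch.fx import symbolic_trace

class MyModule(nn.Module):

    def __init__(self, in_features=64, out_features=128):
        super().__init__()
        self.fc_weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, input):
        return nn.functional.linear(input, self.fc_weight)

traced = symbolic_trace(MyModule())
# prints placeholder -> get_attr -> call_function(linear) -> output,
# so the linear call sits at index 2
print([(i, node.op, node.name) for i, node in enumerate(traced.graph.nodes)])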
@@ -7,6 +7,7 @@ from torch.fx import GraphModule

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
 from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
@@ -49,8 +50,9 @@ def mem_test_for_node_strategy(rank: int,

     # construct the strategy for the output node
     placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0]
+
     output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys()
-                      if key in placeholder_strategy.sharding_specs)
+                      if key.type == OperationDataType.OUTPUT)
     placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[
         output_key]
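The output key is now selected by its OperationDataType instead of by membership in the placeholder strategy. A toy illustration with stand-in key objects (OpData/OpDataType are hypothetical stand-ins for the real OperationData types):

from dataclasses import dataclass
from enum import Enum

class OpDataType(Enum):
    ARG = 0
    PARAM = 1
    OUTPUT = 2

@dataclass(frozen=True)
class OpData:
    name: str
    type: OpDataType

sharding_specs = {
    OpData('input', OpDataType.ARG): 'S0R',
    OpData('weight', OpDataType.PARAM): 'RR',
    OpData('output', OpDataType.OUTPUT): 'S0R',
}
output_key = next(key for key in sharding_specs if key.type == OpDataType.OUTPUT)
assert output_key.name == 'output'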
@@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int,
     )

     # estimated memory
-    metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
-                        target_node.graph.owning_module.get_submodule(target_node.target))
+    if target_node.op == "call_module":
+        metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
+                            target_node.graph.owning_module.get_submodule(target_node.target))
+    else:
+        metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target)
+
     print("estimated memory:")
     print(
         f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"
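For a sense of scale of the numbers printed here, a rough fp32 back-of-envelope for the unbiased Linear(64, 128) module used in the test above (illustrative only, ignoring sharding; not the profiler's exact formula):

# input (8, 8, 16, 64) -> output (8, 8, 16, 128), fp32 = 4 bytes per element
activation_kb = 8 * 8 * 16 * 128 * 4 / 1024   # forward output activation ~ 512.0 kb
parameter_kb = 128 * 64 * 4 / 1024            # weight only (bias=False)  ~ 32.0 kb
print(f"~forward activation: {activation_kb} kb, ~forward param: {parameter_kb} kb")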