mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Add F.conv metainfo (#2069)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel * [fx] add conv metainfo class * [fx] restore profiler * [fx] restore meta profiler * [autoparallel] modify unit test * [fx] modify unit test * [autoparallel] add batchnorm metainfo class * [autoparallel] fix batchnorm unit test function declaration * [fx] restore profiler * [fx] add relu metainfo class * [fx] restore profiler * [autoparallel] modify metainfo input * [autoparallel] add pooling metainfo * [autoparallel] add F.linear metainfo generator * [autoparallel] add binary elementwise metainfo * [fx] recover profiler * [autoparallel] fix forward memory calculation * [autoparallel] modify constants.py * [autoparallel] remove redundant print * [autoparallel] add F.conv metainfo * [autoparallel] linear fixpull/2083/head
parent
f123476666
commit
cf0268da93
|
@ -22,6 +22,9 @@ __all__ = ['convnd_meta_info']
|
|||
@meta_register.register(torch.nn.Conv1d)
|
||||
@meta_register.register(torch.nn.Conv2d)
|
||||
@meta_register.register(torch.nn.Conv3d)
|
||||
@meta_register.register(torch.nn.functional.conv1d)
|
||||
@meta_register.register(torch.nn.functional.conv2d)
|
||||
@meta_register.register(torch.nn.functional.conv3d)
|
||||
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
|
||||
The atens graph of torch.nn.Convnd with bias is
|
||||
|
@ -57,12 +60,19 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
|
||||
# check if conv has bias
|
||||
if len(args) == 4:
|
||||
bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
|
||||
if len(weight_tensors) > 1:
|
||||
has_bias = True
|
||||
# bias tensor's shape only has one dimension
|
||||
if len(weight_tensors[0].shape) == 1:
|
||||
bias_tensor, weight_tensor = weight_tensors
|
||||
else:
|
||||
weight_tensor, bias_tensor = weight_tensors
|
||||
|
||||
else:
|
||||
weight_tensor = weight_tensors[0]
|
||||
|
||||
# construct input args for forward
|
||||
fwd_args = [None] * 9
|
||||
|
|
|
@ -143,7 +143,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
# NOTE: Linear don't have buffer and temp in forward and backward phase
|
||||
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
|
||||
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
|
||||
parameter=activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
|
|
@ -15,6 +15,16 @@ from colossalai.utils import free_port
|
|||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
class ConvFunctionModule(nn.Module):
|
||||
|
||||
def __init__(self, in_channels=4, out_channels=64, kernel_size=3):
|
||||
super().__init__()
|
||||
self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
||||
|
||||
def forward(self, input):
|
||||
return nn.functional.conv2d(input, self.conv_weight)
|
||||
|
||||
|
||||
def _conv_module_mem_test(rank, bias, world_size, port):
|
||||
"""This function is for conv memory test
|
||||
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
|
||||
|
@ -57,5 +67,47 @@ def test_conv_meta_concrete_info_match(bias=False):
|
|||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
|
||||
|
||||
def _conv_function_mem_test(rank, world_size, port):
|
||||
"""This function is for conv function memory test
|
||||
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
|
||||
|
||||
Args:
|
||||
rank: device rank
|
||||
bias: indicate whether conv module need bias
|
||||
world_size: number of devices
|
||||
port: port for initializing process group
|
||||
"""
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = ConvFunctionModule().cuda()
|
||||
input = torch.rand(4, 4, 64, 64).cuda()
|
||||
input.requires_grad = True
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# index of target node in computation graph
|
||||
node_index = 2
|
||||
# total number of target node strategies
|
||||
strategy_number = 16
|
||||
mem_test_for_node_strategy(rank=rank,
|
||||
model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=[input],
|
||||
meta_arg_names=['input'])
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_function_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_meta_concrete_info_match()
|
||||
# test_conv_meta_concrete_info_match()
|
||||
test_conv_function_concrete_info_match()
|
||||
|
|
|
@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
|
|||
model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=2,
|
||||
strategy_number=13,
|
||||
strategy_number=23,
|
||||
input_args=[input],
|
||||
meta_arg_names=["input"])
|
||||
|
||||
|
|
Loading…
Reference in New Issue