[autoparallel] Patch meta information of `torch.nn.Embedding` (#2760)

* [autoparallel] embedding metainfo

* [autoparallel] fix function name in test_activation_metainfo

* [autoparallel] undo changes in activation metainfo and related tests
pull/2766/head^2
Boyuan Yao 2023-02-17 10:39:48 +08:00 committed by GitHub
parent 8e3f66a0d1
commit a2b43e393d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 0 deletions

View File

@ -1,6 +1,7 @@
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .embedding import *
from .linear import *
from .norm import *
from .pooling import *

View File

@ -0,0 +1,52 @@
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["embedding_meta_info"]
@meta_register.register(torch.nn.Embedding)
def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Embedding metainfo generator
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
weight_tensor = next(filter(lambda x: x.type == OperationDataType.PARAM, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
[weight_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
# have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor)]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor)]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@ -0,0 +1,77 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
def test_embedding_meta_info():
meta_func = meta_register.get(torch.nn.Embedding)
# construct meta tensors
input_tensor = torch.randint(0, 50256, (8, 1024), device="meta")
weight_tensor = torch.rand(50257, 1024, device="meta")
output_tensor = torch.rand(8, 1024, 1024, device="meta")
# construct operation data
input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor)
weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor)
output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor)
# construct args and kwargs
args = [input_data, weight_data, output_data]
kwargs = {'inplace': False}
# estimated results
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
# actual results
input_real_tensor = torch.randint(0, 50256, (8, 1024), device="cuda")
embedding_module = torch.nn.Embedding(50257, 1024).cuda()
# fwd
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
output_real_tensor = embedding_module(input_real_tensor)
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
# bwd
upstream_grad = torch.rand_like(output_real_tensor)
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
torch.autograd.backward(output_real_tensor, upstream_grad)
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
bwd_allocated, bwd_peak)
if __name__ == '__main__':
test_embedding_meta_info()