mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add linear metainfo class for auto parallel (#1783)
* [fx] metainfo class for auto parallel
* [fx] add unit test for linear metainfo
* [fx] fix bwd param for linear
* [fx] modify unit test
* [fx] modify unit test
* [fx] modify import
* [fx] modify import
* [fx] modify import
* [fx] move meta profiler to auto parallel
parent e8a9bebc87
commit 05ce3d369f
@@ -0,0 +1,3 @@
from .meta_registry import *
from .metainfo import *
from .registry import meta_register
@@ -0,0 +1 @@
from .linear import *
@@ -0,0 +1,157 @@
from typing import Callable, Dict, List, Tuple, Union

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register

__all__ = ['linear_meta_info']


@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """torch.nn.Linear meta info generator

    The aten graph of torch.nn.Linear with bias is
    graph():
        %input_2 : [#users=2] = placeholder[target=placeholder](default=)
        %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
        %zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
        %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
        %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
        %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
        %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
        %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
        %sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {})
        %view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {})
        %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {})
        %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
        %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {})
        %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
        %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
        %detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
        %detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})

    The one without bias is
    graph():
        %input_2 : [#users=2] = placeholder[target=placeholder](default=)
        %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {})
        %zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
        %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
        %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
        %mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
        %t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
        %mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
        %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {})
        %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
        %t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
        %detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
        %detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
    """

    has_bias: bool = False
    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data

    # flatten the leading dimensions so that input and output are 2D
    if len(input_tensor.shape) > 2:
        input_tensor: torch.Tensor
        input_tensor = input_tensor.view(-1, input_tensor.shape[-1])

    if len(output_tensor.shape) > 2:
        output_tensor: torch.Tensor
        output_tensor = output_tensor.view(-1, output_tensor.shape[-1])

    if len(args) == 4:
        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
        has_bias = True

    if has_bias:
        # calculate cost with bias
        # the fwd op with compute cost is addmm
        # the bwd ops with compute cost are mm * 2 and sum.dim_IntList

        # calculate compute cost
        fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
            [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
            flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
            flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)

        # calculate memory cost
        # NOTE: Linear has no buffer or temp tensors in the forward and backward phases
        # the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor and bias_tensor
        fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
                                     parameter=activation_size(weight_tensor) + activation_size(bias_tensor),
                                     temp=0,
                                     buffer=0)

        # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor;
        # the parameter cost covers the gradients of weight and bias
        bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) + activation_size(weight_tensor) +
                                     activation_size(bias_tensor),
                                     parameter=activation_size(weight_tensor) + activation_size(bias_tensor),
                                     temp=0,
                                     buffer=0)

        # the total cost is the sum of the forward and backward costs
        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    else:
        # calculate cost without bias
        # the fwd op with compute cost is mm
        # the bwd op with compute cost is mm * 2

        # calculate compute cost
        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
            flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))

        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)

        # calculate memory cost
        # NOTE: Linear has no buffer or temp tensors in the forward and backward phases
        # the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor
        fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
                                     parameter=activation_size(weight_tensor),
                                     temp=0,
                                     buffer=0)

        # the backward activation cost is the size of input_tensor and weight_tensor;
        # the parameter cost covers the gradient of weight
        bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) + activation_size(weight_tensor),
                                     parameter=activation_size(weight_tensor),
                                     temp=0,
                                     buffer=0)

        # the total cost is the sum of the forward and backward costs
        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)

        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in
    fwd_in = [input_tensor]

    return compute_cost, memory_cost, fwd_in
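For intuition, the compute costs above reduce to standard matmul flop counting. The sketch below is not part of the commit (the helper name linear_mac_estimate is hypothetical): for an input flattened to (B, K) and a Linear(K, N), the forward matmul performs on the order of B*K*N multiply-accumulates, the backward pass roughly twice that (one matmul for the input gradient, one for the weight gradient), plus a B*N reduction for the bias gradient.

# Hedged sketch of the MAC counts the meta function tallies; the actual
# flop_mapping entries may scale these differently (e.g. 2x for FLOPs).
def linear_mac_estimate(batch: int, in_features: int, out_features: int, bias: bool = True):
    fwd = batch * in_features * out_features       # addmm / mm
    bwd = 2 * batch * in_features * out_features   # grad_input mm + grad_weight mm
    if bias:
        bwd += batch * out_features                # sum over batch for grad_bias
    return fwd, bwd

# e.g. an (8*8*16, 64) flattened input through Linear(64, 128)
print(linear_mac_estimate(8 * 8 * 16, 64, 128, bias=True))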
@@ -0,0 +1,101 @@
from typing import Callable, List

import numpy as np
import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec

from .registry import meta_register

__all__ = ['MetaInfo']


class MetaInfo:
    """MetaInfo class
    This class is used to store meta info based on a sharding strategy and the given
    target function.
    """

    def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
        # compute cost of forward and backward computation
        self.compute_cost: TrainCycleItem

        # memory cost of the forward and backward phases
        self.memory_cost: TrainCycleItem

        # list of input tensors
        self.fwd_in: List[torch.Tensor]

        # sharding strategy
        self._strategy = strategy

        # target function
        self._target = target

        # compute metainfo if possible
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @property
    def strategy(self) -> ShardingStrategy:
        return self._strategy

    @property
    def target(self) -> Callable:
        return self._target

    @strategy.setter
    def strategy(self, strategy: ShardingStrategy) -> None:
        self._strategy = strategy
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    @target.setter
    def target(self, target: Callable) -> None:
        self._target = target
        if self._strategy is not None and self._target is not None:
            self.compute_metainfo()

    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> OperationData:
        """
        Compute the sharded meta tensor based on the given data and sharding spec.
        """
        shard_sequence = sharding_spec.sharding_sequence
        device_mesh = sharding_spec.device_mesh
        shape = operation_data.data.shape

        new_shape = []
        for dim, shard in zip(shape, shard_sequence):
            if shard.is_replica:
                # replica
                new_shape.append(dim)
            else:
                # sharded according to device_mesh shape
                new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))

        return OperationData(name=operation_data.name,
                             data=torch.zeros(new_shape, device="meta"),
                             type=operation_data.type,
                             logical_shape=operation_data.logical_shape)

    def compute_metainfo(self):
        """
        Compute meta info based on the sharding strategy and the given target function.
        """
        assert meta_register.has(self._target), f'{self._target} not found in the meta registry'
        meta_func = meta_register.get(self._target)

        # construct args for meta_func
        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]

        # compute metainfo with meta_func
        self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args)
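The shape arithmetic inside compute_sharded_tensor divides each sharded dimension by the product of the device-mesh dimensions it is sharded over. A minimal standalone sketch with assumed values, matching the (2, 2) mesh used in the tests below:

# Hedged sketch: a dimension sharded over mesh axes in shard_list shrinks
# by the product of those mesh dimensions.
import numpy as np

mesh_shape = (2, 2)           # a 2x2 device mesh
dim, shard_list = 64, [0, 1]  # dimension sharded over both mesh axes
sharded_dim = dim // np.prod([mesh_shape[i] for i in shard_list])
print(sharded_dim)  # 16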
@@ -0,0 +1,32 @@
__all__ = ['Registry']


class Registry:

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            if isinstance(source, (list, tuple)):
                # support registering a list of items for this func
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        assert source in self.store, f'{source} not found in the {self.name} registry'
        target = self.store[source]
        return target

    def has(self, source):
        return source in self.store


meta_register = Registry('meta')
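A quick usage sketch for the Registry above (illustrative only; it assumes the Registry class is in scope, and the keys and function names are hypothetical):

# Hedged usage sketch: register() works as a decorator factory and
# accepts either a single key or a list of keys.
reg = Registry('demo')

@reg.register('linear')               # a single key
def linear_info():
    return 'linear meta info'

@reg.register(['conv', 'conv2d'])     # a list of keys maps to one function
def conv_info():
    return 'conv meta info'

assert reg.has('conv2d')
print(reg.get('linear')())            # -> 'linear meta info'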
@@ -79,9 +79,12 @@ class MemoryCost:
     Args:
         activation (int): the memory cost incurred by the activations in bytes.
         parameter (int): the memory cost incurred by the module parameter in bytes.
+        temp (int): the memory cost incurred by the temporary tensors in bytes.
+        buffer (int): the memory cost incurred by the module buffer in bytes.
     """
     activation: int = 0
     parameter: int = 0
+    temp: int = 0
     buffer: int = 0
@@ -32,7 +32,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     # inputs is a list of length 3.
     input_shapes = [v.shape for v in inputs[1:3]]
     # input_shapes[0]: [batch size, input feature dimension]
-    # input_shapes[1]: [batch size, output feature dimension]
+    # input_shapes[1]: [input feature dimension, output feature dimension]
     assert len(input_shapes[0]) == 2, input_shapes[0]
     assert len(input_shapes[1]) == 2, input_shapes[1]
     batch_size, input_dim = input_shapes[0]
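The corrected comment matches the shape semantics of aten.addmm: in torch.addmm(bias, mat1, mat2), mat1 is (batch, in_features) and mat2 is (in_features, out_features). A small illustration with assumed sizes:

# Hedged illustration of addmm shapes: input_shapes[1] really is
# [input feature dimension, output feature dimension].
import torch

bias = torch.zeros(128)
mat1 = torch.rand(8, 64)    # [batch size, input feature dimension]
mat2 = torch.rand(64, 128)  # [input feature dimension, output feature dimension]
out = torch.addmm(bias, mat1, mat2)
print(out.shape)  # torch.Size([8, 128])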
@@ -0,0 +1,97 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy

if torch.__version__ >= '1.12.0':
    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='PyTorch version is too low')
@parameterize('bias', [True, False])
def test_linear_metainfo(bias):
    model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # build strategy
    strategies_vector = handler.register_strategy(compute_resharding_cost=False)

    # assert module is registered
    assert meta_register.has(linear_mod_node.graph.owning_module.get_submodule(linear_mod_node.target).__class__)

    # check metainfo
    for strategy in strategies_vector:
        strategy: ShardingStrategy
        try:
            metainfo = MetaInfo(strategy,
                                linear_mod_node.graph.owning_module.get_submodule(linear_mod_node.target).__class__)
        except:
            raise RuntimeError(f"Failed to compute metainfo for {strategy}")


def _linear_mem_test(rank, bias, world_size, port):
    """This function runs the linear memory test.
    It prints the real and estimated memory costs; it is not executed
    as part of the unit tests.

    Args:
        bias (bool): whether the Linear layer has a bias. Defaults to True.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.Linear(64, 128, bias=bias)).cuda()
    input = torch.rand(8, 8, 16, 64).cuda()
    input.requires_grad = True
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # memory test
    mem_test_for_node_strategy(rank=rank,
                               model=model,
                               device_mesh=device_mesh,
                               node_index=1,
                               strategy_number=13,
                               input_args=[input],
                               meta_arg_names=["input"])


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_meta_concrete_info_match(bias=False):
    world_size = 4
    run_func_module = partial(_linear_mem_test, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == '__main__':
    # test_linear_metainfo()
    # _linear_mem_test(bias=True)
    test_linear_meta_concrete_info_match()
@@ -0,0 +1,121 @@
import copy
from pprint import pprint
from typing import Dict, List

import torch
from torch.fx import GraphModule

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer

if torch.__version__ >= '1.12.0':
    from colossalai.auto_parallel.meta_profiler import MetaInfo


def mem_test_for_node_strategy(rank: int,
                               model: torch.nn.Module,
                               device_mesh: DeviceMesh,
                               node_index: int,
                               strategy_number: int,
                               input_args: List[torch.Tensor],
                               meta_arg_names: List[str],
                               input_kwargs: Dict[str, torch.Tensor] = {}):
    for strategy_index in range(strategy_number):
        # copy the model to avoid running backward more than once on the same graph
        model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy(
            input_kwargs)

        tracer = ColoTracer()
        input_sample = {}
        for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
            input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
        for meta_kwarg_name, input_kwarg in input_kwargs.items():
            input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
        graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
        gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
        solver_options = SolverOptions(fast=True)
        strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
        strategies_constructor.build_strategies_and_cost()
        target_node = list(graph.nodes)[node_index]

        # solution construction
        # construct the strategy for the target node
        solution_len = len(strategies_constructor.leaf_strategies)
        solution = [0] * solution_len
        solution[node_index] = strategy_index

        # construct the strategy for the output node
        placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0]
        output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys()
                          if key in placeholder_strategy.sharding_specs)
        placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[
            output_key]

        gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
            gm, solution, device_mesh)
        gm = runtime_apply_pass(gm)
        gm.recompile()
        gm: GraphModule

        if rank == 0:
            print("=======================")
            print(f"#strategy_index: {strategy_index}")
            pprint(target_node.strategies_vector[strategy_index])

        # warmup
        with torch.no_grad():
            output = gm(*args_to_shard,
                        sharding_spec_convert_dict=sharding_spec_dict,
                        origin_node_sharding_spec_dict=origin_spec_dict,
                        comm_actions_dict=comm_actions_dict,
                        **kwargs_to_shard)

        del output
        # forward memory compare
        if rank == 0:
            torch.cuda.reset_peak_memory_stats()
            mem_stamp0 = torch.cuda.memory_allocated()
        output = gm(*args_to_shard,
                    sharding_spec_convert_dict=sharding_spec_dict,
                    origin_node_sharding_spec_dict=origin_spec_dict,
                    comm_actions_dict=comm_actions_dict,
                    **kwargs_to_shard)

        if rank == 0:
            # print forward memory allocated and peak memory stats in kb
            print(
                f"forward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb"
            )

        # backward memory compare
        grad_tensors = torch.ones_like(output)
        torch.cuda.reset_peak_memory_stats()
        mem_stamp0 = torch.cuda.memory_allocated()
        torch.autograd.backward(output, grad_tensors)

        if rank == 0:
            # print backward memory allocated and peak memory stats in kb
            print(
                f"backward memory allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, peak memory stats: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb"
            )

            # estimated memory
            metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
                                target_node.graph.owning_module.get_submodule(target_node.target).__class__)
            print("estimated memory:")
            print(
                f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"
            )
            print(
                f"forward temp: {metainfo.memory_cost.fwd.temp / 1024} kb, forward buffer: {metainfo.memory_cost.fwd.buffer / 1024} kb"
            )
            print(
                f"backward activation: {metainfo.memory_cost.bwd.activation / 1024} kb, backward param: {metainfo.memory_cost.bwd.parameter / 1024} kb"
            )
            print(
                f"backward temp: {metainfo.memory_cost.bwd.temp / 1024} kb, backward buffer: {metainfo.memory_cost.bwd.buffer / 1024} kb"
            )
            print("=======================")
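The measurement pattern above snapshots allocated memory, runs the op, then diffs the allocated and peak counters. A minimal standalone sketch of the same pattern (illustrative only; tensor sizes are assumed):

# Hedged sketch of the CUDA memory measurement pattern used in the test utility.
import torch

if torch.cuda.is_available():
    x = torch.rand(1024, 1024, device='cuda', requires_grad=True)
    w = torch.rand(1024, 1024, device='cuda')
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    y = x @ w
    print(f"allocated: {(torch.cuda.memory_allocated() - mem_stamp0) / 1024} kb, "
          f"peak: {(torch.cuda.max_memory_allocated() - mem_stamp0) / 1024} kb")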
@@ -132,7 +132,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-
 class LinearModel(nn.Module):

     def __init__(self):