mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new linear module handler (#1616)
parent 170fa81095
commit d925122020

New file: colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LinearModuleHandler']


class DotProductStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class MatVecStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class LinearProjectionStrategyGenerator(StrategyGenerator_V2):

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
        """TODO: to be implemented"""
        pass

    def validate(self, *args, **kwargs) -> bool:
        """TODO: to be implemented"""
        pass


class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
    """

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch first and last dim of the linear module weight
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy

@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
    """
    A LinearFunctionHandler which deals with the sharding strategies for F.linear.
    """

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=self.node.args[1]._meta_data.shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.node.args[2] is not None:
            physical_bias_operand = OperationData(name=str(self.node.args[2]),
                                                  type=OperationDataType.ARG,
                                                  data=self.node.args[2]._meta_data)
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight operand back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch first and last dim of the weight passed to F.linear
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]

                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy

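For context on why both handlers expose a reversed logical shape for the "other" operand (the intent behind the "use transposed shape" comments and the post_process methods above): nn.Linear stores its weight physically as (out_features, in_features), while the matmul it performs consumes the transpose. A small self-contained check of that fact, independent of the handler code:

    import torch
    import torch.nn as nn

    # nn.Linear(10, 20) keeps its weight as (out_features, in_features) = (20, 10),
    # but it computes x @ weight.t(), i.e. a matmul against a (10, 20) matrix.
    linear = nn.Linear(10, 20)
    x = torch.rand(4, 10)
    assert torch.allclose(linear(x), x @ linear.weight.t() + linear.bias)

    print(tuple(linear.weight.shape))        # (20, 10) -> physical shape
    print(tuple(linear.weight.shape[::-1]))  # (10, 20) -> logical shape used for the strategies
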

Modified file: colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List
-from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
+from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2


 class NodeHandler(ABC):

@@ -36,8 +36,15 @@ class NodeHandler:
         for generator in self.strategy_generator:
             strategies = generator.generate(operand_mapping)
             self.strategies_vector.extend(strategies)

+        # wrap in list() so strategies_vector stays a concrete sequence after post-processing
+        self.strategies_vector = list(map(self.post_process, self.strategies_vector))
         return self.strategies_vector

+    def post_process(self, strategy: ShardingStrategy_V2):
+        # transform the generated strategy
+        # e.g. to process the sharding strategy for the transposed weights
+        return strategy
+
     @abstractmethod
     def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
         """

@@ -46,21 +53,40 @@ class NodeHandler:
         pass

     @abstractmethod
-    def get_operand_mapping(self) -> Dict[str, Operand]:
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         """
-        Returns the mapping between the logical operand name to its physical operands.
-        A logical operand is defined by the strategy generator, for example, a matrix multiplication
-        operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
-        the module input and the physical operand for "other" is the module weight.
+        Returns the mapping from the logical operation data name to the physical data.
+        A logical operation data is a piece of data associated with an operation, either an input or the output.
+        It is defined by the strategy generator; for example, a matrix multiplication operation has two operands,
+        "input" and "other", and one result, "output". For an nn.Linear module, the physical operand for "input"
+        is the module input, the physical operand for "other" is the module weight, and the physical result for
+        "output" is the module output.
         Note that the operand name is specified by the StrategyGenerator object.

         For example:

             # for a linear layer
             mapping = {
-                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
-                "other": Operand(name="weight", type=OperandType.PARAM),
-                "bias": Operand(name="bias", type=OperandType.PARAM)
+                "input": OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
+                "other": OperationData(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
+                "bias": OperationData(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
+                "output": OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
             }
         """
         pass


+class ModuleHandler(NodeHandler):
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # set attributes to access module parameters for convenience
+        assert self.node.graph.owning_module is not None, \
+            f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+        module = self.node.graph.owning_module.get_submodule(self.node.target)
+        named_parameters = list(module.named_parameters(recurse=False))
+        # convert named parameters from list to dict
+        named_parameters = {k: v for k, v in named_parameters}
+        self.module = module
+        self.named_parameters = named_parameters
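To illustrate how the NodeHandler interface above is meant to be extended, here is a hedged sketch of a hypothetical handler for a unary element-wise op. It is not part of this commit; the class name is illustrative only, it returns no strategy generators (none exist for element-wise ops in this diff), and it assumes the import paths implied by the test file at the end of this commit:

    from typing import Dict, List

    from colossalai.auto_parallel.solver.op_handler.node_handler import NodeHandler
    from colossalai.auto_parallel.solver.sharding_strategy import (OperationData, OperationDataType,
                                                                   StrategyGenerator_V2)


    class ElementwiseFunctionHandler(NodeHandler):
        """Hypothetical handler for a unary element-wise op, shown only to illustrate the interface."""

        def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
            # a real handler would return dedicated strategy generators here
            return []

        def get_operation_data_mapping(self) -> Dict[str, OperationData]:
            physical_input = OperationData(name=str(self.node.args[0]),
                                           type=OperationDataType.ARG,
                                           data=self.node.args[0]._meta_data)
            physical_output = OperationData(name=str(self.node),
                                            type=OperationDataType.OUTPUT,
                                            data=self.node._meta_data)
            return {"input": physical_input, "output": physical_output}
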

Modified file: colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,10 @@
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from enum import Enum
+import operator
+import torch
+from functools import reduce

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List, Union, Tuple, Any

@@ -40,18 +44,35 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None


-class OperandType(Enum):
+class OperationDataType(Enum):
     """
-    An operand can come from the argument list of an operator or the parameter list of a module.
+    The data of an operation can come from the argument list of an operator, from the parameter list of a module,
+    or be the output of the operator.
     """
     ARG = 0
     PARAM = 1
+    OUTPUT = 2


 @dataclass
-class Operand:
+class OperationData:
+    """
+    OperationData is the data related to an operator; the data can be an operand or the output.
+
+    Args:
+        name (str): the name of the operation-related data
+        type (OperationDataType): the type of the operation data
+        data (torch.Tensor): the value for this data, usually it is a meta tensor.
+        logical_shape (Tuple[int]): the logical shape of the data, it can be different from its actual shape in memory.
+    """
     name: str
-    type: OperandType
+    type: OperationDataType
+    data: torch.Tensor
+    logical_shape: Tuple[int] = None
+
+    def __post_init__(self):
+        # if no logical shape is specified, use the data shape as the logical shape
+        if self.logical_shape is None:
+            self.logical_shape = self.data.shape


 @dataclass
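
As a quick illustration of the new dataclass above (not part of the diff; the import path follows the test file at the end of this commit), constructing an OperationData for an nn.Linear-style weight on the meta device shows the __post_init__ fallback and the explicitly reversed logical shape used by the linear handlers:

    import torch
    from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType

    weight = torch.empty(20, 10, device='meta')    # physical shape (out_features, in_features)

    # default: logical shape falls back to the physical shape
    op_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight)
    print(tuple(op_data.logical_shape))            # (20, 10)

    # the linear handlers pass the reversed shape explicitly
    op_data_t = OperationData(name="weight", type=OperationDataType.PARAM, data=weight,
                              logical_shape=weight.shape[::-1])
    print(tuple(op_data_t.logical_shape))          # (10, 20)
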
@@ -69,6 +90,20 @@ class TrainCycleItem:
     total: Any


+class CommunicationType(Enum):
+    FWD_ALL_REDUCE = 0
+    BWD_ALL_REDUCE = 1
+
+
+@dataclass
+class CommunicationAction:
+    """
+    The communication action to perform for an operation data, e.g. an all-reduce over a given device mesh dimension.
+    """
+    type: CommunicationType
+    mesh_dim: int
+
+
 @dataclass
 class ShardingStrategy_V2:
     """

@@ -86,12 +121,35 @@ class ShardingStrategy_V2:
     strategy.(default to None)
     """
     name: str
-    output_sharding_spec: ShardingSpec
+    sharding_specs: Dict[OperationData, ShardingSpec] = None
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_sharding_specs: Dict[Operand, ShardingSpec] = None
-    input_resharding_costs: Dict[Operand, List[float]] = None
+    input_resharding_costs: Dict[OperationData, List[float]] = None
+    communication_actions: Dict[OperationData, List[CommunicationAction]] = None
+
+    @property
+    def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        specs = {}
+        specs.update(self._get_sharding_spec(OperationDataType.ARG))
+        specs.update(self._get_sharding_spec(OperationDataType.PARAM))
+        return specs
+
+    @property
+    def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.ARG)
+
+    @property
+    def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.PARAM)
+
+    @property
+    def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.OUTPUT)
+
+    def _get_sharding_spec(self, operation_data_type: OperationDataType):
+        specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
+        return specs


 class StrategyGenerator_V2(ABC):

@@ -104,9 +162,57 @@ class StrategyGenerator_V2:
     def __init__(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh

-    @abstractmethod
-    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
-        """
-        """
-        pass
+    def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Compute the communication cost involved in the forward and backward iteration.
+        """
+
+        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+
+        def _compute_and_add(data: OperationData, action: CommunicationAction):
+            sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
+            dtype = data.data.dtype
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
+            cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
+
+            # add the cost to the forward or backward communication cost
+            if action.type == CommunicationType.FWD_ALL_REDUCE:
+                comm_cost.fwd += cost
+            elif action.type == CommunicationType.BWD_ALL_REDUCE:
+                comm_cost.bwd += cost
+            else:
+                raise ValueError(f"Found unknown CommunicationType {action.type}")
+
+        # check if communication action exists
+        # if so, loop over each action and compute the cost of each action
+        if strategy.communication_actions is not None:
+            for operand, actions in strategy.communication_actions.items():
+                for action in actions:
+                    _compute_and_add(operand, action)
+
+        # update the communication cost attribute in-place
+        strategy.communication_cost = comm_cost
+        return strategy
+
+    @abstractmethod
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the computation flops.
+        """
+        pass
+
+    @abstractmethod
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the memory cost in bytes.
+        """
+        pass
+
+    @abstractmethod
+    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
+        """
+        Generate all possible sharding strategies for this operation.
+        """
+        pass
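
For intuition about the cost helper above, the per-action all-reduce volume is simply the element size times the number of elements in the sharded tensor. A standalone sketch of that byte computation, using an assumed fp32 per-device shard of shape (10, 10) (not part of the diff):

    import operator
    import torch
    from functools import reduce

    sharded_shape = (10, 10)                                                    # assumed per-device shard
    size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()  # 4 bytes per fp32 element
    num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
    print(num_bytes)                                                            # 400
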

New file: unit test for the linear handlers (test file path not shown in this diff)
@@ -0,0 +1,104 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


def test_linear_function_handler():
    model = nn.Linear(10, 20).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_func_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


if __name__ == '__main__':
    test_linear_module_handler()
    test_linear_function_handler()
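
For quick manual debugging of either test, the mapping returned by a handler can be dumped directly; this loop is illustrative only and assumes the same `mapping` variable as in the tests above:

    for name, op_data in mapping.items():
        print(f"{name}: name={op_data.name}, type={op_data.type}, "
              f"physical shape={tuple(op_data.data.shape)}, logical shape={tuple(op_data.logical_shape)}")

Both tests follow the usual pytest conventions, so they can be collected by pytest or run directly through the __main__ guard.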