2022-08-23 06:23:08 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2022-08-19 08:51:38 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from torch.fx.node import Node
|
2022-08-23 06:23:08 +00:00
|
|
|
from typing import Dict
|
2022-08-19 08:51:38 +00:00
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
|
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
|
|
|
|
2022-08-23 06:23:08 +00:00
|
|
|
from .sharding_strategy import StrategiesVector
|
|
|
|
|
2022-08-19 08:51:38 +00:00
|
|
|
|
2022-08-23 06:23:08 +00:00
|
|
|
class OperatorHandler(ABC):
    """
    The OperatorHandler is an abstract class used to generate every possible strategy for an operator node.

    Argument:
        node(Node): the fx graph node this handler generates strategies for.
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        strategies_vector(StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
    """

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 shape_consistency_manager: ShapeConsistencyManager):
        self.node = node
        # Snapshot the graph neighbours of this node as plain lists.
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.shape_consistency_manager = shape_consistency_manager

        # find the module and its parameters associated with this node
        # this can be used to compute the compute/communication/sharding cost
        if self.node.op == 'call_module':
            module = node.graph.owning_module.get_submodule(node.target)
            # only the parameters owned directly by this module, as a name -> parameter dict
            named_parameters = dict(module.named_parameters(recurse=False))
        else:
            module = None
            named_parameters = None
        self.module = module
        self.module_named_parameters = named_parameters

    @abstractmethod
    def register_strategy(self) -> StrategiesVector:
        """
        Generate the strategies for this node. Must be implemented by every concrete handler.
        """
        pass

    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
        """
        Generate the sharding spec of the tensor based on the given dim_partition_dict
        where the key is the tensor dimension and the value is the mesh dimension for sharding.

        Argument:
            tensor(torch.Tensor): the tensor whose shape is used as the entire shape of the spec.
            dim_partition_dict(Dict[int, int]): tensor dimension -> mesh dimension used to shard it.

        Return:
            ShardingSpec: the sharding spec built on self.device_mesh.
        """
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec

    def _generate_resharding_costs(self, sharding_spec_for_input) -> Dict[Node, list]:
        """
        Compute the resharding costs with this specific strategy.

        Costs are computed for every node in self.predecessor_node, paired in order
        with the entries of sharding_spec_for_input.
        NOTE(review): earlier notes on this method disagreed on whether the weight's
        resharding cost is counted; as written, a weight is only counted if it appears
        among the predecessor nodes -- confirm the intended behaviour.

        Argument:
            sharding_spec_for_input: the sharding specs wanted by this strategy, one per
                predecessor node and in the same order as self.predecessor_node.

        Return:
            resharding_costs: a dictionary keyed by predecessor node. resharding_costs[node][j]
                is the cost of transforming the j-th strategy in that node's strategies_vector
                into the sharding spec wanted in this strategy. The dictionary is empty when
                the node has no predecessors.
        """
        resharding_costs = {}
        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
            costs_for_input = []
            for strategy in input_node.strategies_vector:
                # shape_consistency returns a 3-tuple; only the last item (the cost) is needed here
                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
                costs_for_input.append(resharding_cost)
            resharding_costs[input_node] = costs_for_input
        # Return the whole mapping. Previously the last scalar cost was returned by
        # mistake, which also failed when there was nothing to iterate over.
        return resharding_costs
|