mirror of https://github.com/hpcaitech/ColossalAI
[fx] temporarily used (#1215)
parent ae7d3f4927
commit 291e22aac6
@@ -4,75 +4,37 @@ from torch.fx.node import Node
 from torch.fx.passes.split_module import split_module
 
 import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
 
 
-def all_gather_function(input_):
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    group = gpc.get_group(ParallelMode.PARALLEL_1D)
-    torch.distributed.all_gather(tensor_list, input_, group=group)
-    output = torch.cat(tensor_list, dim=-1).contiguous()
-    return output
-
-
-def all_reduce_function(input_):
-    if gpc.get_world_size(ParallelMode.PARALLEL_1D) == 1:
-        return input_
-    torch.distributed.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-    return input_
-
-
-def weight_split(weight, dim):
-    # TODO: this function will be refactored by using ColoTensor dist_spec when a stable reshaper feature is ready to use.
-    num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-    shape = weight.shape
-    length = shape[dim] // num_partition
-    sharded_weight_list = []
-    for i in range(num_partition):
-        sharded_weight_list.append(weight.narrow(dim, i * length, length))
-    return sharded_weight_list[gpc.get_local_rank(ParallelMode.PARALLEL_1D)]
-
-
-def replace_all_uses_except_replaced(node, replace_node):
-    """
-    Replace all uses of ``node`` in the Graph with the Node ``replace_node``,
-    except when the user of ``node`` is ``replace_node`` itself.
-
-    Args:
-        replace_node (Node): The node to replace all uses of ``node`` with.
-
-    Returns:
-        The list of Nodes on which this change was made.
-    """
-    to_process = list(node.users)
-    for use_node in to_process:
-        if use_node == replace_node:
-            continue
-
-        def may_replace_node(n):
-            if n == node:
-                return replace_node
-            else:
-                return n
-
-        new_args = map_arg(use_node.args, may_replace_node)
-        new_kwargs = map_arg(use_node.kwargs, may_replace_node)
-        use_node._args = new_args
-        use_node._kwargs = new_kwargs
-        for old_use in use_node._input_nodes.keys():
-            old_use.users.pop(use_node)
-        use_node._input_nodes = {}
-        map_arg(use_node._args, lambda n: use_node._input_nodes.setdefault(n))
-        map_arg(use_node._kwargs, lambda n: use_node._input_nodes.setdefault(n))
-        for new_use in use_node._input_nodes.keys():
-            new_use.users.setdefault(use_node)
-    return to_process
+def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
+    """weight_split
+    split an nn.Parameter
+
+    Args:
+        weight (torch.nn.parameter.Parameter): a torch Parameter instance
+        dim (int): the dimension to shard along
+
+    Returns:
+        torch.nn.parameter.Parameter: the parameter, with its data wrapped in a ColoTensor that carries the shard spec
+    """
+    # TODO: This func temporarily works with no materialization
+    # Append a Tensor spec to target_module.weight.shard
+    # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
+    # assert isinstance(weight, torch.nn.parameter.Parameter), \
+    #     f'The type of the input tensor should be torch.nn.parameter' \
+    #     f'Your input tensor is {type(weight)}'
+
+    # FIXME(): a ProcessGroup is initialized for this tensor here, and it only has the TP comm group,
+    # since we only consider the TP-only case.
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+
+    spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    # As a spec has been constructed, why not directly convert the tensor to a ColoTensor?
+    # setattr(weight, "fx_attr", spec)
+    weight.data = ColoTensor(data=weight.data, spec=spec)
+    return weight
 
 
 def column_shard_linear_pass(gm: torch.fx.GraphModule):
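Note: the rewritten weight_split no longer physically narrows the tensor per rank; it attaches a 1D tensor-parallel (TP1D) shard spec and wraps the parameter's storage in a ColoTensor, deferring materialization as the TODO says. A minimal usage sketch follows; the driver is hypothetical and assumes a torch.distributed process group is already initialized (e.g. launched with torchrun) and that this commit's colossalai.tensor API is importable:

import torch
import torch.distributed as dist

# Hypothetical driver: assumes launch via `torchrun --nproc_per_node=2 demo.py`,
# which sets RANK/WORLD_SIZE/MASTER_ADDR for init_process_group to read.
dist.init_process_group(backend='gloo')

linear = torch.nn.Linear(8, 4)
# Annotate the weight with a shard along the output dimension (column style).
sharded = weight_split(linear.weight, dim=0)
# The data is now a ColoTensor carrying a TP1D shard spec; per the TODO above,
# nothing is physically re-partitioned at this point.
print(type(sharded.data))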
@@ -81,14 +43,10 @@ def column_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=0)
+                target_module.weight = weight_split(target_module.weight, dim=0)
                 if target_module.bias is not None:
                     target_module.bias.data = weight_split(target_module.bias.data, dim=0)
-
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_gather_function, args=(node,))
-                replace_all_uses_except_replaced(node, new_node)
     gm.recompile()
     return gm
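For reference, column-parallel sharding splits the weight of a Linear layer along the output dimension (dim 0), so each rank produces a slice of the output features; concatenating the slices recovers the full result, which is what the removed all_gather_function did after the matmul. A single-process sketch of that identity, with world_size fixed at 2 for illustration:

import torch

X = torch.randn(4, 8)                            # activations
W = torch.randn(6, 8)                            # full weight of Linear(8, 6)
W0, W1 = W.chunk(2, dim=0)                       # each rank holds a slice of the output features
Y = torch.cat([X @ W0.t(), X @ W1.t()], dim=-1)  # concatenation = the removed all_gather
assert torch.allclose(Y, X @ W.t(), atol=1e-5)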
@@ -99,20 +57,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=-1)
-
-                # insert input sharding node before the sharded linear node
-                with mod_graph.inserting_before(node):
-                    input_node_list = list(node._input_nodes.keys())
-                    assert len(input_node_list) == 1, 'linear forward must have one and only one input tensor.'
-                    input_node = input_node_list[0]
-                    new_input_node = mod_graph.create_node('call_function', weight_split, args=(input_node, -1))
-                    replace_all_uses_except_replaced(input_node, new_input_node)
-
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_reduce_function, args=(node,))
-                replace_all_uses_except_replaced(node, new_node)
+                target_module.weight = weight_split(target_module.weight, dim=-1)
     gm.recompile()
     return gm
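Likewise, row-parallel sharding splits the weight along the input dimension (dim -1), so each rank computes a partial product over its slice of the input features, and summing the partials recovers the full result, which is what the removed all_reduce_function did. A single-process sketch, again with world_size fixed at 2:

import torch

X = torch.randn(4, 8)                   # activations
W = torch.randn(6, 8)                   # full weight of Linear(8, 6)
W0, W1 = W.chunk(2, dim=-1)             # each rank holds a slice of the input features
X0, X1 = X.chunk(2, dim=-1)             # the matching slice of the activation
Y = X0 @ W0.t() + X1 @ W1.t()           # the sum = the removed all_reduce
assert torch.allclose(Y, X @ W.t(), atol=1e-5)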