2022-06-15 08:36:46 +00:00
|
|
|
import torch
|
2022-07-11 07:51:48 +00:00
|
|
|
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
|
2022-06-15 08:36:46 +00:00
|
|
|
|
|
|
|
|
2022-07-07 05:37:31 +00:00
|
|
|
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
    """Attach a 1D tensor-parallel shard spec to ``weight``.

    The tensor is not split or copied here. A ``ColoTensorSpec`` describing a
    1D tensor-parallel (``ComputePattern.TP1D``) shard along ``dim`` is stored
    on the tensor as its ``fx_attr`` attribute, and the same object is returned.

    Args:
        weight (torch.Tensor): the tensor to annotate (callers in this module
            pass ``torch.nn.Parameter`` instances).
        dim (int): the dimension the tensor should be sharded along.

    Returns:
        ``weight`` itself (same object), now carrying an ``fx_attr`` spec.

    Raises:
        AssertionError: if ``weight`` is not a ``torch.Tensor``.
    """
    assert isinstance(weight, torch.Tensor), (
        'The type of the input tensor should be torch.Tensor (e.g. torch.nn.Parameter). '
        f'Your Input tensor is {type(weight)}')

    # FIXME: a fresh ProcessGroup is built on every call and it only has a TP
    # communication group, i.e. only the TP-only case (tp_degree == world size)
    # is considered.
    # NOTE: torch.distributed.get_world_size() expects the default process group
    # to be initialized before this function is called.
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    # NOTE(review): the spec is only recorded as an attribute; the tensor is not
    # converted to a ColoTensor (e.g. via ColoTensor.from_torch_tensor) here.
    setattr(weight, "fx_attr", spec)
    return weight
|
2022-06-15 08:36:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
def column_shard_linear_pass(gm: torch.fx.GraphModule):
    """Annotate every ``nn.Linear`` called in ``gm`` for column (output-dim) sharding.

    For each ``call_module`` node whose target is a ``torch.nn.Linear``, the
    weight (shape ``(out_features, in_features)``) is tagged via ``weight_split``
    for sharding along dim 0. If present, the bias is tagged along dim 0 as well.
    The graph module is then recompiled.

    Args:
        gm (torch.fx.GraphModule): the traced module to annotate; its Linear
            submodules are modified in place.

    Returns:
        torch.fx.GraphModule: ``gm`` itself.
    """
    mod_graph = gm.graph
    for node in mod_graph.nodes:
        if node.op != "call_module":
            continue
        target_module = node.graph.owning_module.get_submodule(node.target)
        if not isinstance(target_module, torch.nn.Linear):
            continue
        target_module.weight = weight_split(target_module.weight, dim=0)
        if target_module.bias is not None:
            # Annotate the bias Parameter itself. ``Parameter.data`` returns a new
            # tensor object on every access, so an ``fx_attr`` set on it would be
            # attached to a throwaway object and lost.
            target_module.bias = weight_split(target_module.bias, dim=0)

    gm.recompile()
    return gm
|
|
|
|
|
|
|
|
|
|
|
|
def row_shard_linear_pass(gm: torch.fx.GraphModule):
    """Annotate every ``nn.Linear`` called in ``gm`` for row (input-dim) sharding.

    Each ``call_module`` node whose submodule is a ``torch.nn.Linear`` has its
    weight tagged via ``weight_split`` for sharding along the last dimension
    (``in_features``). Biases are left untouched. ``gm`` is recompiled and
    returned.

    Args:
        gm (torch.fx.GraphModule): the traced module to annotate in place.

    Returns:
        torch.fx.GraphModule: the same ``gm`` object.
    """
    for n in gm.graph.nodes:
        if n.op != "call_module":
            continue
        submodule = n.graph.owning_module.get_submodule(n.target)
        if not isinstance(submodule, torch.nn.Linear):
            continue
        # Row parallelism splits the weight's last (input-feature) axis only.
        submodule.weight = weight_split(submodule.weight, dim=-1)

    gm.recompile()
    return gm
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: add a pass that processes element-wise ops; once it exists, column and row sharding strategies can be mixed.
|