mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fx shard 1d pass bug fixing (#1220)
parent
11973d892d
commit
db1bef9032
|
@ -1,13 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.node import map_arg
|
from torch.fx.node import map_arg
|
||||||
from torch.fx.node import Node
|
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
|
||||||
from torch.fx.passes.split_module import split_module
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
|
|
||||||
|
|
||||||
|
|
||||||
def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
|
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
|
||||||
"""weight_split
|
"""weight_split
|
||||||
split a nn.Parameter
|
split a nn.Parameter
|
||||||
|
|
||||||
|
@ -18,22 +14,20 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.par
|
||||||
Returns:
|
Returns:
|
||||||
_type_: _description_
|
_type_: _description_
|
||||||
"""
|
"""
|
||||||
#TODO: This func temporarily works with no materialization
|
|
||||||
# Append a Tensor spec to target_module.weight.shard
|
# Append a Tensor spec to target_module.weight.shard
|
||||||
# Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
|
# Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
|
||||||
# assert isinstance(weight, torch.nn.parameter.Parameter), \
|
assert isinstance(weight, torch.Tensor), \
|
||||||
# f'The type of the input tensor should be torch.nn.parameter' \
|
f'The type of the input tensor should be torch.nn.parameter' \
|
||||||
# f'Your Input tensor is {type(weight)}'
|
f'Your Input tensor is {type(weight)}'
|
||||||
|
|
||||||
# FIXME() I initialized a PG for this tensor. Only has TP comm group.
|
# FIXME() I initialized a PG for this tensor. Only has TP comm group.
|
||||||
# we only consider the TP-only caes.
|
# we only consider the TP-only caes.
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
pg = ProcessGroup(tp_degree=world_size)
|
pg = ProcessGroup(tp_degree=world_size)
|
||||||
|
|
||||||
spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||||
# As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
|
# As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
|
||||||
# setattr(weight, "fx_attr", spec)
|
setattr(weight, "fx_attr", spec)
|
||||||
weight.data = ColoTensor(data=weight.data, spec=spec)
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +52,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
|
||||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
if isinstance(target_module, torch.nn.Linear):
|
if isinstance(target_module, torch.nn.Linear):
|
||||||
target_module.weight = weight_split(target_module.weight, dim=-1)
|
target_module.weight = weight_split(target_module.weight, dim=-1)
|
||||||
|
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue