diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index 43c056b60..9cbdd3d45 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,13 +1,9 @@
 import torch
 from torch.fx.node import map_arg
-from torch.fx.node import Node
-from torch.fx.passes.split_module import split_module
-
-import colossalai
-from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
 
 
-def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
+def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
     """weight_split split a nn.Parameter
 
@@ -18,22 +14,20 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.par
     Returns:
         _type_: _description_
     """
-    #TODO: This func temporarily works with no materialization
     # Append a Tensor spec to target_module.weight.shard
     # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
-    # assert isinstance(weight, torch.nn.parameter.Parameter), \
-    #     f'The type of the input tensor should be torch.nn.parameter' \
-    #     f'Your Input tensor is {type(weight)}'
+    assert isinstance(weight, torch.Tensor), \
+        f'The type of the input tensor should be torch.nn.parameter' \
+        f'Your Input tensor is {type(weight)}'
 
     # FIXME() I initialized a PG for this tensor. Only has TP comm group.
     # we only consider the TP-only caes.
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
 
-    spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
-    # setattr(weight, "fx_attr", spec)
-    weight.data = ColoTensor(data=weight.data, spec=spec)
+    setattr(weight, "fx_attr", spec)
     return weight
 
@@ -58,6 +52,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
                 target_module.weight = weight_split(target_module.weight, dim=-1)
+    gm.recompile()
     return gm
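
Usage note (an illustrative sketch, not part of the patch): after this change, weight_split() no longer converts the weight into a ColoTensor; it only attaches a ColoTensorSpec under the fx_attr attribute so that a later stage can materialize the sharding. The snippet below shows one way row_shard_linear_pass() might be driven. It assumes torch.distributed has already been initialized, since weight_split() calls torch.distributed.get_world_size(), and the MLP module is a made-up toy model, not something from this file.

import torch
from colossalai.fx.passes.shard_1d_pass import row_shard_linear_pass


class MLP(torch.nn.Module):
    # hypothetical toy model; any module containing an nn.Linear would do
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


# requires an initialized torch.distributed process group (TP-only in this pass)
gm = torch.fx.symbolic_trace(MLP())
# the pass annotates every nn.Linear weight it finds and recompiles the graph
gm = row_shard_linear_pass(gm)
# the weight now carries the ColoTensorSpec annotation added by weight_split()
print(getattr(gm.linear.weight, "fx_attr", None))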