diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py
index 87d24f18e..586aec124 100644
--- a/colossalai/shardformer/layer/layers.py
+++ b/colossalai/shardformer/layer/layers.py
@@ -329,7 +329,11 @@ class Linear1D_Row(ParallelModule):
                 src_rank = 0
             else:
                 src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+            origin_device = self.bias.device
+            self.bias = self.bias.cuda()
             dist.broadcast(self.bias, src=src_rank, group=self.process_group)
+            self.bias = self.bias.to(origin_device)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py
index afb1fc003..b58edadfe 100644
--- a/colossalai/tensor/d_tensor/api.py
+++ b/colossalai/tensor/d_tensor/api.py
@@ -10,9 +10,21 @@ from .d_tensor import DTensor
 from .sharding_spec import ShardingSpec
 
 
-def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_rowwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the first dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
@@ -24,12 +36,28 @@ def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
 
 
-def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_colwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the last dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
@@ -41,4 +69,8 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
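
For context, here is a minimal usage sketch of the new `inplace` flag on `shard_rowwise`/`shard_colwise`. The script is illustrative and not part of the patch; it assumes `torch.distributed` has already been initialized (e.g. via `torchrun` or `colossalai.launch`), so the functions can fall back to the global process group when no group or device mesh is passed.

```python
# Illustrative only -- not part of the patch. Assumes torch.distributed is
# already initialized, so shard_* can use the global process group by default.
import torch

from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise

weight = torch.randn(8, 8)

# Default behaviour after this patch: inplace=False detaches and clones the
# input before wrapping it, so `weight` itself is left untouched.
row_sharded = shard_rowwise(weight)    # partitions dim 0 across the group

# inplace=True hands the original tensor object to DTensor directly, skipping
# the extra copy; any re-layout DTensor performs then affects `weight` itself.
col_sharded = shard_colwise(weight, inplace=True)    # partitions the last dim
```

Judging from the diff alone, cloning by default keeps existing parameters from being mutated silently when they are wrapped, while `inplace=True` is the opt-in path for callers who want to avoid the extra copy.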