mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] removed inplace tensor sharding (#4018)
parent 3893fa1a8d
commit 45d9384346
@@ -329,7 +329,11 @@ class Linear1D_Row(ParallelModule):
                 src_rank = 0
             else:
                 src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+            origin_device = self.bias.device
+            self.bias = self.bias.cuda()
             dist.broadcast(self.bias, src=src_rank, group=self.process_group)
+            self.bias = self.bias.to(origin_device)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
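
The lines added above exist because NCCL collectives only operate on CUDA tensors: the bias is moved to the GPU for the broadcast and then restored to wherever it originally lived. A minimal standalone sketch of the same pattern (the helper name broadcast_param is hypothetical; it assumes torch.distributed has already been initialized with the NCCL backend):

import torch
import torch.distributed as dist


def broadcast_param(param: torch.Tensor, src_rank: int, group=None) -> torch.Tensor:
    # NCCL only accepts CUDA tensors, so temporarily move the parameter to the
    # current GPU, broadcast it, then restore its original placement.
    origin_device = param.device
    param = param.cuda()
    dist.broadcast(param, src=src_rank, group=group)
    return param.to(origin_device)
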
@@ -10,9 +10,21 @@ from .d_tensor import DTensor
 from .sharding_spec import ShardingSpec
 
 
-def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_rowwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the first dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
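
With the widened signature, sharding is no longer destructive by default. A usage sketch under assumed conditions (a process group has already been initialized; the import path below is a guess, since the diff only shows module-relative imports):

import torch

# Assumed import path; the diff only shows "from .d_tensor import DTensor" etc.
from colossalai.tensor.d_tensor.api import shard_rowwise

weight = torch.randn(8, 4)

# Default (inplace=False): the function clones the tensor before wrapping it
# in a DTensor, so weight itself keeps its full (8, 4) shape.
d_weight = shard_rowwise(weight)

# inplace=True keeps the pre-#4018 behaviour and shards the passed-in tensor itself.
d_weight_inplace = shard_rowwise(weight, inplace=True)
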
@@ -24,12 +36,28 @@ def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
 
 
-def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_colwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the last dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
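
The detach().clone() guard is what makes the default non-destructive: the clone shares neither storage nor autograd history with the caller's tensor, so whatever the DTensor constructor later does to its local data cannot touch the original parameter. A minimal sketch of that property (plain PyTorch, no distributed setup required):

import torch

param = torch.randn(4, 4, requires_grad=True)

# detach().clone() gives an independent copy that shares neither storage nor
# autograd history with the original parameter.
local = param.detach().clone()
local.mul_(0)                      # mutate the copy in place

assert param.abs().sum() > 0       # the original parameter is unaffected
assert not local.requires_grad     # and the copy is outside the autograd graph
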
@@ -41,4 +69,8 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
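
The only functional difference between shard_rowwise and shard_colwise is the sharding spec: dim_partition_dict={0: [0]} splits dim 0, while {-1: [0]} splits the last dim. A shape-only illustration for a hypothetical 2-rank, 1D device mesh (torch.chunk stands in for the actual distributed sharding):

import torch

# For an (8, 4) weight on a hypothetical 2-rank, 1D device mesh:
#   shard_rowwise -> dim 0 is split    -> each rank holds a (4, 4) local shard
#   shard_colwise -> last dim is split -> each rank holds an (8, 2) local shard
weight = torch.randn(8, 4)
row_shards = torch.chunk(weight, chunks=2, dim=0)     # analogous to {0: [0]}
col_shards = torch.chunk(weight, chunks=2, dim=-1)    # analogous to {-1: [0]}

assert row_shards[0].shape == (4, 4)
assert col_shards[0].shape == (8, 2)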