
[dtensor] polish sharding spec docstring (#3838)

* [dtensor] polish sharding spec docstring

* [dtensor] polish sharding spec example docstring
Hongxin Liu committed via GitHub 2 years ago (commit 7c9f2ed6dd)
1 changed file: colossalai/tensor/d_tensor/sharding_spec.py (27 changed lines)

@@ -116,21 +116,21 @@ class DimSpec:
     def dim_diff(self, other):
         '''
-        The difference between two _DimSpec.
+        The difference between two DimSpec.
         Argument:
-            other(_DimSpec): the dim spec to compare with.
+            other(DimSpec): the dim spec to compare with.
         Return:
             difference(int): the difference between two _DimSpec.
         Example:
-            dim_spec = _DimSpec([0])
-            other_dim_spec = _DimSpec([0, 1])
+            ```python
+            dim_spec = DimSpec([0])
+            other_dim_spec = DimSpec([0, 1])
             print(dim_spec.difference(other_dim_spec))
-        Output:
-            5
+            # output: 5
+            ```
         '''
         difference = self.difference_dict[(str(self), str(other))]
         return difference
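
For context on what `dim_diff` documents: a dim spec stringifies to `R` (replicated) or `S0`/`S01`/... (sharded on one or more logical mesh axes), and the difference between two specs is a lookup in a precomputed `difference_dict` keyed by those string forms. The sketch below is a standalone illustration of that lookup pattern; `ToyDimSpec` and its `DIFFERENCE_TABLE` are hypothetical stand-ins rather than colossalai's `DimSpec`, and the value 5 is simply copied from the docstring example.

```python
# Standalone sketch (not the colossalai implementation) of the lookup pattern
# described in the dim_diff docstring.
class ToyDimSpec:
    # Hypothetical cost table; the real difference_dict is built by DimSpec itself.
    DIFFERENCE_TABLE = {
        ("S0", "S01"): 5,   # illustrative value taken from the docstring example
        ("S01", "S0"): 5,
        ("R", "R"): 0,
    }

    def __init__(self, shard_list):
        # Empty shard_list means replicated ("R"); e.g. [0, 1] means sharded as "S01".
        self.shard_list = shard_list

    def __str__(self):
        if not self.shard_list:
            return "R"
        return "S" + "".join(str(axis) for axis in self.shard_list)

    def dim_diff(self, other):
        # Same pattern as the docstring describes: key the table by the string forms.
        return self.DIFFERENCE_TABLE[(str(self), str(other))]


dim_spec = ToyDimSpec([0])
other_dim_spec = ToyDimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))   # 5, as in the docstring example
```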
@@ -142,9 +142,13 @@ class ShardingSpec:
     [R, R, S0, S1], which means
     Argument:
+        dim_size (int): The number of dimensions of the tensor to be sharded.
         dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
-            and the value of the key describe which logical axis will be sharded in that dimension.
+            and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None.
+            E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1.
         sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
+            Generally, users should specify either dim_partition_dict or sharding_sequence.
+            If both are given, users must ensure that they are consistent with each other. Defaults to None.
     '''
     def __init__(self,
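
As a rough illustration of how the two optional arguments relate, the sketch below expands a `dim_partition_dict` into the `[S01, R, R]`-style view that `sharding_sequence` gives directly. The helper `to_sharding_sequence` is hypothetical (it is not part of the library) and assumes every dimension not listed in the dict stays replicated.

```python
# Hypothetical helper, not colossalai code: build the flat sharding view from
# dim_partition_dict for a tensor with dim_size dimensions.
def to_sharding_sequence(dim_size, dim_partition_dict):
    sequence = ["R"] * dim_size          # default: every dimension replicated
    for dim, axes in dim_partition_dict.items():
        # e.g. {0: [0, 1]} -> dimension 0 becomes "S01"
        sequence[dim] = "S" + "".join(str(axis) for axis in axes)
    return sequence

print(to_sharding_sequence(3, {0: [0, 1]}))   # ['S01', 'R', 'R']
```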
@@ -208,6 +212,7 @@ class ShardingSpec:
         pair of sharding sequence.
         Example:
+            ```python
             dim_partition_dict = {0: [0, 1]}
             # DistSpec:
             # shard_sequence: S01,R,R
@@ -219,10 +224,8 @@ class ShardingSpec:
             # device_mesh_shape: (4, 4)
             sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
             print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
-        Output:
-            25
+            # output: 25
+            ```
         Argument:
             other(ShardingSpec): The ShardingSpec to compared with.
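
To make the example above easier to follow, here is a standalone sketch of the idea behind `sharding_sequence_difference`: compare the two sharding sequences dimension by dimension and sum the per-pair differences. `PAIR_COST` is a made-up table used only for illustration; in the library the per-pair value comes from the dim spec difference logic shown earlier, so the number 25 below should not be read as the real cost model.

```python
# Made-up per-pair cost table for illustration only.
PAIR_COST = {
    ("S01", "R"): 25,
    ("R", "S01"): 25,
    ("R", "R"): 0,
}

def sharding_sequence_difference(seq, other_seq):
    # Sum the pairwise differences across all tensor dimensions.
    return sum(PAIR_COST[(a, b)] for a, b in zip(seq, other_seq))

# S01,R,R compared with R,R,R -> 25 + 0 + 0 under the table above.
print(sharding_sequence_difference(["S01", "R", "R"], ["R", "R", "R"]))   # 25
```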
