From 7c9f2ed6dd3dd27a099295ac22be8f3d1508c010 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 25 May 2023 13:09:42 +0800 Subject: [PATCH] [dtensor] polish sharding spec docstring (#3838) * [dtensor] polish sharding spec docstring * [dtensor] polish sharding spec example docstring --- colossalai/tensor/d_tensor/sharding_spec.py | 31 +++++++++++---------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 2ea0c4db8..b927f6dfb 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -116,21 +116,21 @@ class DimSpec: def dim_diff(self, other): ''' - The difference between two _DimSpec. + The difference between two DimSpec. Argument: - other(_DimSpec): the dim spec to compare with. + other(DimSpec): the dim spec to compare with. Return: difference(int): the difference between two _DimSpec. Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) + ```python + dim_spec = DimSpec([0]) + other_dim_spec = DimSpec([0, 1]) print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + # output: 5 + ``` ''' difference = self.difference_dict[(str(self), str(other))] return difference @@ -142,9 +142,13 @@ class ShardingSpec: [R, R, S0, S1], which means Argument: - dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, - and the value of the key describe which logical axis will be sharded in that dimension. - sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + dim_size (int): The number of dimensions of the tensor to be sharded. + dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, + and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None. + E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1. + sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + Generally, users should specify either dim_partition_dict or sharding_sequence. + If both are given, users must ensure that they are consistent with each other. Defaults to None. ''' def __init__(self, @@ -208,6 +212,7 @@ class ShardingSpec: pair of sharding sequence. Example: + ```python dim_partition_dict = {0: [0, 1]} # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) dim_partition_dict_to_compare = {0: [0, 1]} # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) - - Output: - 25 + # output: 25 + ``` Argument: other(ShardingSpec): The ShardingSpec to compared with.