From d35bd7d0e64f10161ab4f6abc8776f14d19bba38 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 5 Jul 2023 15:20:59 +0800 Subject: [PATCH] [shardformer] fix type hint --- colossalai/shardformer/shard/shard_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index fba2c27a2..75fad4eb7 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -15,8 +15,8 @@ class ShardConfig: The config for sharding the huggingface model Args: - tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - pipeline_stage_manager (PipelineStageManager): The pipeline stage manager, defaults to None, which means no pipeline. + tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group. + pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline. enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False.