[chore] minor fix after rebase

moe_sp
hxwang 2024-07-19 07:53:40 +00:00
parent 783aafa327
commit c27f5d9731
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
3 changed files with 5 additions and 3 deletions

View File

@ -39,7 +39,6 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from colossalai.logging import get_dist_logger
from .pp_plugin_base import PipelinePluginBase from .pp_plugin_base import PipelinePluginBase
@ -653,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
model: HybridParallelModule, model: HybridParallelModule,
use_pipeline: bool, use_pipeline: bool,
param_info: OrderedDict, param_info: OrderedDict,
pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
initial_scale: int = 2**16, # grad scaler config initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1, min_scale: int = 1,
growth_factor: float = 2.0, growth_factor: float = 2.0,
@ -685,6 +685,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer=optimizer, optimizer=optimizer,
initial_scale=initial_scale, initial_scale=initial_scale,
min_scale=min_scale, min_scale=min_scale,
pg_to_param_list=pg_to_param_list,
growth_factor=growth_factor, growth_factor=growth_factor,
backoff_factor=backoff_factor, backoff_factor=backoff_factor,
growth_interval=growth_interval, growth_interval=growth_interval,
@ -1124,7 +1125,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.logger.info( self.logger.info(
f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}", f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
ranks=[0, 1, 2, 3, 4, 5, 6, 7], ranks=[0],
) )
self.shard_config = ShardConfig( self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group, tensor_parallel_process_group=self.tp_group,

View File

@ -55,6 +55,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
partition_grad: bool = False, # stage 2 flag partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None, forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
): ):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result" WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
if not force_overlap_comm and (overlap_communication or partition_grad): if not force_overlap_comm and (overlap_communication or partition_grad):
@ -95,6 +96,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
pp_process_group=pp_process_group, pp_process_group=pp_process_group,
forced_dtype=forced_dtype, forced_dtype=forced_dtype,
pg_to_param_list=pg_param_list, pg_to_param_list=pg_param_list,
overlap_allgather=overlap_allgather,
) )

View File

@ -47,7 +47,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
dtype, precision = torch.float16, "fp16" dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank()) torch.cuda.set_device(dist.get_rank())
print(config)
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
pp_size=pp_size, pp_size=pp_size,
num_microbatches=pp_size, num_microbatches=pp_size,