mirror of https://github.com/hpcaitech/ColossalAI
[chore] minor fix after rebase
parent 783aafa327
commit c27f5d9731
@@ -39,7 +39,6 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
-from colossalai.logging import get_dist_logger

 from .pp_plugin_base import PipelinePluginBase

@@ -653,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
+        pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -685,6 +685,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             optimizer=optimizer,
             initial_scale=initial_scale,
             min_scale=min_scale,
+            pg_to_param_list=pg_to_param_list,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
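For context, the new pg_to_param_list argument maps each ZeRO process group to the parameters it should shard and reduce. A minimal sketch of how a caller might build such a mapping; the two-group split and the "experts" naming convention below are illustrative assumptions, not part of this commit:

from typing import Dict, List

import torch
from torch.distributed import ProcessGroup


def build_pg_to_param_list(
    model: torch.nn.Module,
    dense_pg: ProcessGroup,
    moe_pg: ProcessGroup,
) -> Dict[ProcessGroup, List[torch.nn.Parameter]]:
    # Route expert parameters to the (usually smaller) MoE group and
    # everything else to the dense data-parallel group.
    mapping: Dict[ProcessGroup, List[torch.nn.Parameter]] = {dense_pg: [], moe_pg: []}
    for name, param in model.named_parameters():
        target = moe_pg if "experts" in name else dense_pg  # hypothetical convention
        mapping[target].append(param)
    return mapping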
@@ -1124,7 +1125,7 @@ class HybridParallelPlugin(PipelinePluginBase):

         self.logger.info(
             f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
-            ranks=[0, 1, 2, 3, 4, 5, 6, 7],
+            ranks=[0],
         )
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
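The ranks=[0] change above makes the topology message print once rather than once per process, and stops hardcoding a list that assumes a world size of at least eight ranks. A small sketch of the pattern with ColossalAI's distributed logger (the message text is a placeholder):

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# Every rank reaches this call, but only global rank 0 emits the record,
# so an 8-GPU job logs one line instead of eight near-duplicates.
logger.info("process groups initialized", ranks=[0])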
@@ -55,6 +55,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
+        overlap_allgather: bool = False,
     ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
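The warning above guards a real failure mode: if an expert receives no tokens, its parameters get no gradient on that rank, and ranks can then disagree on which reduction collectives to enter. A hypothetical defensive check to illustrate the constraint, not ColossalAI code:

from typing import List


def assert_all_experts_routed(tokens_per_expert: List[int]) -> None:
    # tokens_per_expert[i] = tokens dispatched to expert i on this rank (assumed input)
    idle = [i for i, n in enumerate(tokens_per_expert) if n == 0]
    if idle:
        raise RuntimeError(
            f"experts {idle} received no tokens; their backward is empty and "
            "a collective mismatch (hang) or inconsistent result may follow"
        )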
@@ -95,6 +96,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
             pg_to_param_list=pg_param_list,
+            overlap_allgather=overlap_allgather,
         )

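overlap_allgather is forwarded to the underlying LowLevelZeroOptimizer, which (per the ZeroOpHook and wait_all_gather_handle imports above) can re-gather sharded parameters asynchronously. A generic sketch of the overlap idea in plain PyTorch, not the ColossalAI hook implementation:

import torch
import torch.distributed as dist


def prefetch_full_param(shard: torch.Tensor, world_size: int):
    # Kick off the all-gather without blocking; unrelated forward compute
    # can proceed while NCCL fills `full` in the background.
    full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(full, shard, async_op=True)
    return full, handle

# ...overlapped compute runs here...
# handle.wait()  # block only at the point the full parameter is first used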
@@ -47,7 +47,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dtype, precision = torch.float16, "fp16"
     torch.cuda.set_device(dist.get_rank())

-    print(config)
     plugin = MoeHybridParallelPlugin(
         pp_size=pp_size,
         num_microbatches=pp_size,