[nfc] fix typo colossalai/shardformer/ (#5133)

digger yu, 2024-01-04 16:21:55 +08:00, committed by GitHub
parent 451e9142b8
commit b0b53a171c
4 changed files with 6 additions and 6 deletions


@@ -79,9 +79,9 @@ Following are the description `ShardConfig`'s arguments:
 - `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
-- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
+- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
 ### Write your own policy
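
The corrected entries are arguments of `ShardConfig`. A minimal sketch of how they might be set, not part of this commit; the import path and the presence of an already-initialized distributed launch (which tensor parallelism needs for its process group) are assumptions:

```python
# Sketch only, not from this commit. Import path and an initialized
# torch.distributed / Colossal-AI launch are assumptions.
from colossalai.shardformer import ShardConfig

shard_config = ShardConfig(
    enable_sequence_parallelism=True,  # partition non-tensor-parallel regions along the sequence dim
    enable_sequence_overlap=True,      # only valid when enable_sequence_parallelism is True
    extra_kwargs={},                   # free-form dict of extra kwargs passed on to ShardFormer
)
```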


@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value):
     r"""
     Set the element to value of a list object
-    It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value
+    It used like set_obj_list_element(obj, 'layers[0]', new_layer), it will set obj.layers[0] to value
     Args:
         obj (object): The object to set
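
The corrected docstring spells out the helper's call pattern. A hypothetical illustration of that contract; the module path is not shown in this hunk and is assumed, and `ToyModel` is invented for the example:

```python
# Hypothetical usage of the documented contract; the module path below is an assumption.
from colossalai.shardformer._utils import set_obj_list_element


class ToyModel:
    def __init__(self):
        self.layers = ["embedding", "attention", "mlp"]


toy = ToyModel()
# Per the docstring, set_obj_list_element(obj, 'layers[0]', new_layer) sets obj.layers[0] to the new value.
set_obj_list_element(toy, "layers[0]", "new_embedding")
assert toy.layers[0] == "new_embedding"
```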


@@ -22,8 +22,8 @@ class ShardConfig:
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
-        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
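
Besides the corrected docstring lines, the hunk shows the first two dataclass fields. A sketch, not part of this commit, of filling them in from an already-initialized `torch.distributed` job; the flag names come from the docstring above, everything else is assumed:

```python
# Sketch only; assumes torch.distributed is already initialized (e.g. via colossalai.launch).
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))  # group used for tensor parallelism
shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,  # Optional[ProcessGroup], defaults to None
    pipeline_stage_manager=None,             # no pipeline parallelism in this sketch
    enable_flash_attention=True,
    enable_jit_fused=True,
)
```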


@@ -37,7 +37,7 @@ class ModelSharder(object):
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
-        # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None)
+        # get shared params before release unheld layers, this avoid misjudgment of shared params (None is None)
         shared_params = self.policy.get_shared_params()
         held_layers = self._release_unheld_layers()
         self._replace_module(include=held_layers)
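
The corrected comment is about ordering: shared parameters are recorded before unheld layers are released, because once the released weights are gone an identity check can no longer tell real sharing apart. A standalone illustration of that pitfall, not from the repo:

```python
# Why shared params are collected before layers are released: after release,
# both references are None and an identity check is always True ("None is None"),
# so genuinely tied parameters could no longer be distinguished.
embed_weight = object()          # stands in for a tied embedding weight
lm_head_weight = embed_weight    # tied to the same object

shared_before_release = embed_weight is lm_head_weight  # True, real sharing

embed_weight = lm_head_weight = None                     # layers released
shared_after_release = embed_weight is lm_head_weight    # also True, but meaningless

print(shared_before_release, shared_after_release)
```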