[hotfix] fix argument naming in docs and examples (#4083)

Branch: pull/4105/head
Baizhou Zhang 2023-06-26 23:50:04 +08:00 committed by GitHub
parent e89b127d8e
commit 4da324cd60
8 changed files with 40 additions and 41 deletions

View File

@@ -34,9 +34,9 @@ class ColossalAIStrategy(DDPStrategy):
             If it is cuda, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
         pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
         force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
-        search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
+        search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
         hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
-        min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
+        min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
         gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
         reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
         overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
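
Aside: the `_m` suffix in the new names means the value is counted in multiples of 2^20 bytes rather than loosely in "MB". A minimal sketch of the conversion; the helper name is illustrative and not part of the codebase:

```python
# Illustrative helper (not part of this diff): what a *_m value means in bytes.
def m_to_bytes(value_m: float) -> int:
    """Convert a *_m argument, counted in multiples of 2**20 bytes, to bytes."""
    return int(value_m * (1 << 20))

assert m_to_bytes(32) == 32 * 2**20    # search_range_m=32   -> 33,554,432 bytes
assert m_to_bytes(2.5) == 2_621_440    # min_chunk_size_m=2.5 -> 2.5 * 2**20 bytes
```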
@@ -61,9 +61,9 @@ class ColossalAIStrategy(DDPStrategy):
                  placement_policy: str = 'cuda',
                  pin_memory: bool = True,    # only for stage 3
                  force_outputs_fp32: bool = False,    # only for stage 3
-                 search_range_mb: int = 32,    # only for stage 3
+                 search_range_m: int = 32,    # only for stage 3
                  hidden_dim: Optional[int] = None,    # only for stage 3
-                 min_chunk_size_mb: float = 32,    # only for stage 3
+                 min_chunk_size_m: float = 32,    # only for stage 3
                  gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
                  reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
                  overlap_communication: bool = True,    # only for stage 1&2
@@ -83,57 +83,51 @@ class ColossalAIStrategy(DDPStrategy):
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
-            warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. '
-                'Please load weights after strategy.prepare()'
-            )
+            warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
+                          'Please load weights after strategy.prepare()')
         if stage == 3 and precision == 'fp32':
             warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
             precision = 'fp16'
         self.precision = precision
         self.shard_init = shard_init
-        optim_kwargs = dict(
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        optim_kwargs = dict(initial_scale=initial_scale,
+                            growth_factor=growth_factor,
+                            backoff_factor=backoff_factor,
+                            growth_interval=growth_interval,
+                            hysteresis=hysteresis,
+                            min_scale=min_scale,
+                            max_scale=max_scale,
+                            max_norm=max_norm,
+                            norm_type=norm_type)
         # NOTE: dist should be initialized before calling get_current_device()
         if stage == 3:
             plugin_initializer = lambda: GeminiPlugin(
-                # gemini_config
+                # gemini_config
                 device=get_current_device(),
                 placement_policy=placement_policy,
                 precision=precision,
                 pin_memory=pin_memory,
                 force_outputs_fp32=force_outputs_fp32,
                 strict_ddp_mode=shard_init,
-                search_range_mb=search_range_mb,
+                search_range_m=search_range_m,
                 hidden_dim=hidden_dim,
-                min_chunk_size_mb=min_chunk_size_mb,
-                # zero_optim_config
+                min_chunk_size_m=min_chunk_size_m,
+                # zero_optim_config
                 gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-                # optim_config
-                **optim_kwargs
-            )
+                # optim_config
+                **optim_kwargs)
         else:
             plugin_initializer = lambda: LowLevelZeroPlugin(
-                # zero_config
+                # zero_config
                 stage=stage,
                 precision=precision,
-                # zero_optim_config
+                # zero_optim_config
                 reduce_bucket_size_in_m=reduce_bucket_size,
                 overlap_communication=overlap_communication,
                 cpu_offload=(placement_policy == 'cpu'),
-                # optim_config
-                **optim_kwargs
-            )
+                # optim_config
+                **optim_kwargs)
         super().__init__(seed, plugin_initializer)
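
For reference, a hedged usage sketch showing where the renamed keyword arguments land; the import path is an assumption based on the class shown above and is not part of this diff:

```python
# Sketch only; import path assumed from the coati application layout.
from coati.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(
    stage=3,
    placement_policy='cuda',
    search_range_m=32,      # renamed from search_range_mb
    min_chunk_size_m=32,    # renamed from min_chunk_size_mb
)
```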

View File

@@ -181,7 +181,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy:
                       device=get_current_device(),
                       placement_policy=placement_policy,
                       pin_memory=True,
-                      search_range_mb=32)
+                      search_range_m=32)
     return model
 ```
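
A short usage sketch for the helper above; the `model` and `pg` objects come from the surrounding tutorial and are assumed here:

```python
# Sketch: wrap an existing torch model with the helper shown above.
# model (torch.nn.Module) and pg (ProcessGroup) are set up earlier in the tutorial.
model = gemini_zero_dpp(model, pg, placement_policy='cuda')
```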
@@ -190,3 +190,5 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy:
 The above optimization we made allows us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then we can complete the model training on a single GPU when running the file.
 The GPT-2 example is accessible at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt).
 <!-- doc-test-command: torchrun --standalone --nproc_per_node=1 parallelize_your_training_like_Megatron.py -->

View File

@@ -67,12 +67,12 @@ Define the model parameters as follows:
 chunk_manager = init_chunk_manager(model=module,
                                    init_device=device,
                                    hidden_dim=hidden_dim,
-                                   search_range_mb=search_range_mb,
-                                   min_chunk_size_mb=min_chunk_size_mb)
+                                   search_range_m=search_range_m,
+                                   min_chunk_size_m=min_chunk_size_m)
 gemini_manager = GeminiManager(placement_policy, chunk_manager)
 ```
-`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_mb` is the the minimum chunk size in MegaByte. If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
+`hidden_dim` is the hidden dimension of the DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, that is fine; a default value of 1024 will be used. `min_chunk_size_m` is a floating-point number: the minimum chunk size divided by 2^20 (e.g., if min_chunk_size_m=2.5, the minimum chunk size is 2.5*(2^20) bytes). If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
 Initialization of the optimizer.
 ```python

View File

@@ -165,7 +165,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy:
                       device=get_current_device(),
                       placement_policy=placement_policy,
                       pin_memory=True,
-                      search_range_mb=32)
+                      search_range_m=32)
     return model
 ```
@@ -174,3 +174,6 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy:
 The above optimizations allow us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then the model training completes on a single GPU when the file is run.
 The GPT-2 example is available at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt).
 <!-- doc-test-command: torchrun --standalone --nproc_per_node=1 parallelize_your_training_like_Megatron.py -->

View File

@@ -66,13 +66,13 @@ with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_
 chunk_manager = init_chunk_manager(model=module,
                                    init_device=device,
                                    hidden_dim=hidden_dim,
-                                   search_range_mb=search_range_mb,
-                                   min_chunk_size_mb=min_chunk_size_mb)
+                                   search_range_m=search_range_m,
+                                   min_chunk_size_m=min_chunk_size_m)
 gemini_manager = GeminiManager(placement_policy, chunk_manager)
 model = ZeroDDP(model, gemini_manager)
 ```
-`hidden dim` is the hidden dimension of the DNN. Users can provide this argument to speed up the search. It is fine if users do not know it before training; a default value of 1024 will be used. `min_chunk_size_mb` is the minimum chunk size in megabytes. If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
+`hidden dim` is the hidden dimension of the DNN. Users can provide this argument to speed up the search. It is fine if users do not know it before training; a default value of 1024 will be used. `min_chunk_size_m` is the minimum chunk size in units of 2^20 bytes. If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
 Initialization of the optimizer.
 ```python

View File

@@ -88,7 +88,7 @@ def main():
                              placement_policy=args.placement,
                              pin_memory=True,
                              hidden_dim=model.config.hidden_size,
-                             search_range_mb=128)
+                             search_range_m=128)
         optim_config = dict(gpu_margin_mem_ratio=0.)
     else:
         raise RuntimeError
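
For context, a hedged sketch of how a gemini config dict like the one above is typically consumed in the GPT example; the `zero_model_wrapper` call and the surrounding `model` object are assumptions about the script, not part of this diff:

```python
# Sketch; the wrapper call is an assumption about the surrounding example script.
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper

gemini_config = dict(device=get_current_device(),
                     placement_policy='cuda',
                     pin_memory=True,
                     hidden_dim=model.config.hidden_size,
                     search_range_m=128)    # renamed from search_range_mb
model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
```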

View File

View File

@@ -258,7 +258,7 @@ def main():
                                  placement_policy=args.placement,
                                  pin_memory=True,
                                  strict_ddp_mode=args.tp_degree == 1,
-                                 search_range_mb=128,
+                                 search_range_m=128,
                                  hidden_dim=model.config.n_embd,
                                  gpu_margin_mem_ratio=0.)
     else: