mirror of https://github.com/InternLM/InternLM
fix(conflicts): merge main to develop
commit
7b7b23ed89
|
@ -6,11 +6,14 @@ The system code file structure is shown below:
|
||||||
├── internlm # Main directory of the system code
|
├── internlm # Main directory of the system code
|
||||||
│ ├── apis # Interface module, containing some interface functions related to inference, etc.
|
│ ├── apis # Interface module, containing some interface functions related to inference, etc.
|
||||||
│ ├── core # Core module, managing parallel context and training scheduling engine for training and inference
|
│ ├── core # Core module, managing parallel context and training scheduling engine for training and inference
|
||||||
|
│ │ ├── communication # Communication module, responsible for p2p communication in pipeline parallel scheduling
|
||||||
│ │ ├── context # Context module, mainly responsible for initializing parallel process groups and managing parallel context
|
│ │ ├── context # Context module, mainly responsible for initializing parallel process groups and managing parallel context
|
||||||
│ │ │ ├── parallel_context.py
|
│ │ │ ├── parallel_context.py
|
||||||
│ │ │ └── process_group_initializer.py
|
│ │ │ └── process_group_initializer.py
|
||||||
|
│ │ ├── scheduler # Scheduling module, which manages schedulers for parallel training, including non-pipeline and pipeline parallel schedulers
|
||||||
|
│ │ │ ├── no_pipeline_scheduler.py
|
||||||
|
│ │ │ └── pipeline_scheduler.py
|
||||||
│ │ ├── engine.py # Responsible for managing the training and evaluation process of the model
|
│ │ ├── engine.py # Responsible for managing the training and evaluation process of the model
|
||||||
│ │ ├── no_pipeline_scheduler.py # Scheduler for parallel training
|
|
||||||
│ │ └── trainer.py # Responsible for managing the training engine and scheduler
|
│ │ └── trainer.py # Responsible for managing the training engine and scheduler
|
||||||
│ ├── data # Data module, responsible for managing dataset generation and processing
|
│ ├── data # Data module, responsible for managing dataset generation and processing
|
||||||
│ ├── initialize # Initialization module, responsible for managing distributed environment startup and trainer initialization
|
│ ├── initialize # Initialization module, responsible for managing distributed environment startup and trainer initialization
|
||||||
|
|
|
@ -165,8 +165,9 @@ Training parallel configuration example:
|
||||||
```python
|
```python
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
zero1=8,
|
zero1=8,
|
||||||
pipeline=1,
|
|
||||||
tensor=1,
|
tensor=1,
|
||||||
|
pipeline=dict(size=1, interleaved_overlap=True),
|
||||||
|
sequence_parallel=False,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -174,8 +175,11 @@ parallel = dict(
|
||||||
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
|
- When `size <= 0`, the size of the zero1 process group is equal to the size of the data parallel process group, so the optimizer state parameters will be split within the data parallel range.
|
||||||
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
|
- When `size == 1`, zero1 is not used, and all data parallel groups retain the complete optimizer state parameters.
|
||||||
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
|
- When `size > 1` and `size <= data_parallel_world_size`, the zero1 process group is a subset of the data parallel process group.
|
||||||
- pipeline: pipeline parallel size, default value is 1
|
- tensor: tensor parallel size, usually the number of GPUs per node, default is 1
|
||||||
- tensor: tensor parallel size, usually the number of GPUs per node, default value is 1
|
- pipeline: pipeline parallel strategy
|
||||||
|
- size: pipeline parallel size, the default value is 1
|
||||||
|
- interleaved_overlap: bool type, when interleaved scheduling, enable or disable communication optimization, the default value is False
|
||||||
|
- sequence_parallel: Whether to enable sequence parallelism, the default value is False
|
||||||
|
|
||||||
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
|
Note: `Data parallel size = Total number of GPUs / Pipeline parallel size / Tensor parallel size`
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,14 @@
|
||||||
├── internlm # 系统代码的主目录
|
├── internlm # 系统代码的主目录
|
||||||
│ ├── apis # 接口模块,包含一些关于推理等的接口函数
|
│ ├── apis # 接口模块,包含一些关于推理等的接口函数
|
||||||
│ ├── core # 核心模块,管理用于训练和推理的 parallel context 和训练调度引擎
|
│ ├── core # 核心模块,管理用于训练和推理的 parallel context 和训练调度引擎
|
||||||
|
│ │ ├── communication # 通信模块,负责流水线并行调度中的p2p通信
|
||||||
│ │ ├── context # context 模块,主要负责初始化并行进程组,并管理 parallel context
|
│ │ ├── context # context 模块,主要负责初始化并行进程组,并管理 parallel context
|
||||||
│ │ │ ├── parallel_context.py
|
│ │ │ ├── parallel_context.py
|
||||||
│ │ │ └── process_group_initializer.py
|
│ │ │ └── process_group_initializer.py
|
||||||
|
│ │ ├── scheduler # 调度模块,管理并行训练的调度器,包括非流水线并行调度器和流水线并行调度器
|
||||||
|
│ │ │ ├── no_pipeline_scheduler.py
|
||||||
|
│ │ │ └── pipeline_scheduler.py
|
||||||
│ │ ├── engine.py # 负责管理模型的训练和评估过程
|
│ │ ├── engine.py # 负责管理模型的训练和评估过程
|
||||||
│ │ ├── no_pipeline_scheduler.py # 并行训练的调度器
|
|
||||||
│ │ └── trainer.py # 负责管理训练引擎和调度器
|
│ │ └── trainer.py # 负责管理训练引擎和调度器
|
||||||
│ ├── data # 数据模块,负责管理数据集生成和处理
|
│ ├── data # 数据模块,负责管理数据集生成和处理
|
||||||
│ ├── initialize # 初始化模块,负责管理分布式环境启动和训练器初始化
|
│ ├── initialize # 初始化模块,负责管理分布式环境启动和训练器初始化
|
||||||
|
|
|
@ -151,16 +151,20 @@ model = dict(
|
||||||
```python
|
```python
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
zero1=8,
|
zero1=8,
|
||||||
pipeline=1,
|
|
||||||
tensor=1,
|
tensor=1,
|
||||||
|
pipeline=dict(size=1, interleaved_overlap=True),
|
||||||
|
sequence_parallel=False,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
- zero1:zero 并行策略,分如下三种情况,默认值为 -1
|
- zero1:zero 并行策略,分如下三种情况,默认值为 -1
|
||||||
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
|
- 当`size <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配
|
||||||
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
|
- 当`size == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数
|
||||||
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
|
- 当`size > 1`且`size <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集
|
||||||
- pipeline:流水线并行大小,默认值为 1
|
|
||||||
- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1
|
- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1
|
||||||
|
- pipeline:流水线并行策略
|
||||||
|
- size:流水线并行大小,默认值为 1
|
||||||
|
- interleaved_overlap:bool 类型,交错式调度时,开启或关闭通信优化,默认值为关闭
|
||||||
|
- sequence_parallel:是否开启序列化并行,默认值为 False
|
||||||
|
|
||||||
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
|
注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`
|
||||||
|
|
||||||
|
|
|
@ -137,15 +137,13 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
""" """
|
""" """
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Generate and save the inverse frequency buffer (non trainable)
|
# Generate and save the inverse frequency buffer (non trainable)
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
self.scale_base = scale_base
|
self.scale_base = scale_base
|
||||||
scale = (
|
self.scale = (
|
||||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||||
if scale_base > 0
|
if scale_base > 0
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.register_buffer("scale", scale)
|
|
||||||
|
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._cos_cached = None
|
self._cos_cached = None
|
||||||
|
@ -220,3 +218,15 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
self._cos_k_cached[seqlen_offset:],
|
self._cos_k_cached[seqlen_offset:],
|
||||||
self._sin_k_cached[seqlen_offset:],
|
self._sin_k_cached[seqlen_offset:],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _single_forward(self, x, indexes=0):
|
||||||
|
assert self.scale is None
|
||||||
|
self._update_cos_sin_cache(x, indexes)
|
||||||
|
x = x[None, ...]
|
||||||
|
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _single_eval_forward(self, x, seqlen_offset=0):
|
||||||
|
assert self.scale is None
|
||||||
|
self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
|
||||||
|
return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
||||||
|
|
Loading…
Reference in New Issue