Update mixed_precision.rst

pull/319/head
Wenwen Qu 2023-09-26 16:45:11 +08:00 committed by GitHub
parent 344f543c4c
commit 9d0c41e85b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -1,17 +1,17 @@
混合精度 混合精度
----------------- -----------------
混合精度是指在模型训练的过程中同时使用16位和32位浮点类型是一种在最小化精度损失的前提下加速模型训练的方法。 混合精度是指在模型训练的过程中同时使用16位和32位浮点类型,是一种在最小化精度损失的前提下加速模型训练的方法。
混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性并在其余部分利用半精度浮点数加速训练并可以减少内存使用在评估指标如准确率方面仍可以获得同等的训练效果。 混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性并在其余部分利用半精度浮点数加速训练并可以减少内存使用在评估指标如准确率方面仍可以获得同等的训练效果。
.. autoclass:: internlm.core.naive_amp.NaiveAMPModel .. autoclass:: internlm.core.naive_amp.NaiveAMPModel
InternLM默认将模型转换为16位精度进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用 InternLM默认将模型转换为16位浮点数类型进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用
.. code-block:: python .. code-block:: python
set_fp32_attr_to_module(/*fp32 module*/) set_fp32_attr_to_module(/*fp32 module*/)
将模型的某个子模块设置为32精度进行训练InternLM会在模型训练时自动将数据类型转换成需要的精度。 将模型的某个子模块设置为32位浮点数类型进行训练InternLM会在模型训练时自动将数据类型转换成需要的精度。
例如: 例如: