feat(*): support fp32 training (#155)

* support float32 training

* fix lint

* add adaptation in model/utils.py

* remove some unnecessary code

* fix lint

* feat(optim): add support for fp32 zero

* Revert "Merge pull request #2 from SolenoidWGT/fp32_zero"

This reverts commit 53fc50b0e5, reversing
changes made to 40f24d0a73.

revert commit

* merge develop

* Update utils.py

* support fp32 in zero optimizer

* modify the dtype

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
ytxiong 2023-08-04 16:05:30 +08:00 committed by GitHub
parent 0268d8eda1
commit 853becfb6e
5 changed files with 142 additions and 19 deletions


@@ -162,8 +162,22 @@ def args_sanity_check():
         gpc.config.model.dtype = torch.bfloat16
     elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
         gpc.config.model.dtype = torch.float16
+    elif gpc.config.model.dtype == "torch.float32":
+        assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
+        gpc.config.model.dtype = torch.float32
+    elif gpc.config.model.dtype == "torch.tf32":
+        assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+        gpc.config.model.dtype = torch.float32
     else:
-        assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
+        assert gpc.config.model.dtype in [
+            "torch.float16",
+            "torch.half",
+            "torch.bfloat16",
+            "torch.float32",
+            "torch.tf32",
+        ]

     if gpc.is_rank_for_log():
         logger.info("+" * 15 + " Model Info " + "+" * 15)  # pylint: disable=W1201
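For context, the new dtypes are selected from the training config's model dict. A minimal sketch of such a config excerpt, assuming the usual InternLM config layout; only dtype and use_flash_attn are the fields this change actually reads, everything else is illustrative:

    # hypothetical config excerpt: fp32 (or tf32) training requires disabling flash-attn,
    # as asserted in args_sanity_check above
    model = dict(
        dtype="torch.float32",   # or "torch.tf32" to keep fp32 parameters with TF32 matmuls
        use_flash_attn=False,    # flash-attn kernels only handle fp16/bf16 inputs
    )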


@@ -5,15 +5,13 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-from flash_attn.ops.fused_dense import (
-    ColumnParallelLinear,
-    RowParallelLinear,
-    fused_dense_func,
-)
+from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.utils import fused_dense_func_torch


 class ScaleColumnParallelLinear(nn.Linear):
@@ -61,7 +59,7 @@ class ScaleColumnParallelLinear(nn.Linear):
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return fused_dense_func(
+        return fused_dense_func_torch(
             input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
         )
@@ -107,11 +105,33 @@ class RewardModelLinear(ScaleColumnParallelLinear):
             weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
         else:
             weight = self.weight
-        return fused_dense_func(
+        return fused_dense_func_torch(
             input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
         )


+class ColumnParallelLinearTorch(ColumnParallelLinear):
+    def forward(self, x):
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return fused_dense_func_torch(
+            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
+        )
+
+
+class RowParallelLinearTorch(RowParallelLinear):
+    def forward(self, x):
+        """
+        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
+        a reduce_scatter of the result.
+        """
+        out = fused_dense_func_torch(x, self.weight, self.bias)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
+
+
 class FeedForward(nn.Module):
     """
     FeedForward.
@@ -143,7 +163,7 @@ class FeedForward(nn.Module):
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = ColumnParallelLinear(
+        self.w1 = ColumnParallelLinearTorch(
             in_features,
             hidden_features,
             process_group,
@@ -152,10 +172,10 @@ class FeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = ColumnParallelLinear(
+        self.w2 = ColumnParallelLinearTorch(
             in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
         )
-        self.w3 = RowParallelLinear(
+        self.w3 = RowParallelLinearTorch(
             hidden_features,
             out_features,
             process_group,
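The two Torch subclasses above keep flash-attn's tensor-parallel wiring but route the linear through fused_dense_func_torch, so fp32 tensors, which the fused flash-attn backward kernel does not accept, fall back to a plain torch weight-gradient path (see the FusedDenseFuncTorch diff further down). A single-process sketch of the row-parallel math they rely on, illustrative only and with no process group involved:

    import torch

    # Each "rank" holds a slice of the input features and the matching slice of the
    # weight's input dimension; summing the partial matmuls (what the all_reduce /
    # reduce_scatter step does across ranks) reproduces the unsharded linear.
    torch.manual_seed(0)
    x = torch.randn(4, 8)    # (batch, in_features)
    w = torch.randn(6, 8)    # (out_features, in_features)

    reference = x @ w.t()
    partials = [x[:, :4] @ w[:, :4].t(), x[:, 4:] @ w[:, 4:].t()]
    assert torch.allclose(reference, sum(partials), atol=1e-5)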


@@ -12,12 +12,12 @@ from flash_attn.modules.mha import (
     SelfAttention,
     _update_kv_cache,
 )
-from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import RotaryEmbedding
+from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch


 class MHA(nn.Module):
@@ -78,7 +78,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

         # notice here should change bias=True
-        self.Wqkv = ColumnParallelLinear(
+        self.Wqkv = ColumnParallelLinearTorch(
             embed_dim,
             3 * embed_dim,
             process_group,
@@ -95,7 +95,7 @@ class MHA(nn.Module):
         )

         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(
+        self.out_proj = RowParallelLinearTorch(
             embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module


@@ -1,7 +1,19 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from typing import Optional
+
 import torch
+import torch.nn.functional as F
+from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.utils.distributed import (
+    all_gather_raw,
+    all_reduce_raw,
+    reduce_scatter_raw,
+)
+from torch import Tensor
+from torch.cuda.amp import custom_bwd
+from torch.distributed import ProcessGroup

 from internlm.core.context import global_context as gpc
@@ -72,6 +84,78 @@ class _GatherForwardSplitBackward(torch.autograd.Function):

 def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


+def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
+    assert input.dtype == grad_output.dtype
+    grad_weight = torch.matmul(grad_output.t(), input)
+    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
+    return grad_weight, grad_bias
+
+
+# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
+class FusedDenseFuncTorch(FusedDenseFunc):
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        grad_output = grad_output.contiguous()
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+        process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
+        if ctx.compute_weight_gradient:
+            x, weight = ctx.saved_tensors
+            if process_group is not None and sequence_parallel:
+                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+            else:
+                total_x = x
+        else:
+            (weight,) = ctx.saved_tensors
+            total_x = None
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = F.linear(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            if process_group is not None:
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
+        else:
+            grad_input = None
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+            if process_group is not None and sequence_parallel:
+                handle_x.wait()
+            # we remove the cuda independence, which is different from flash_attn.
+            grad_weight, grad_bias = linear_bias_wgrad_torch(
+                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
+            )
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+        if process_group is not None and ctx.needs_input_grad[0]:
+            handle_grad_input.wait()
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+def fused_dense_func_torch(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    return_residual: bool = False,
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True,
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
+    else:
+        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
+
+
 def try_import_RMSNorm():
     """
@@ -86,4 +170,4 @@ def try_import_RMSNorm():
         logger = get_logger(__file__)
         logger.warn("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
         from internlm.model.norm import RMSNormTorch as RMSNorm
-        return RMSNorm
+        return RMSNorm
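A quick fp32 sanity sketch, not part of the commit, showing that the fallback's weight/bias gradient formula in linear_bias_wgrad_torch agrees with autograd on a plain F.linear:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(16, 32, requires_grad=True)   # fp32 activations
    w = torch.randn(8, 32, requires_grad=True)    # fp32 weight
    b = torch.randn(8, requires_grad=True)        # fp32 bias

    out = F.linear(x, w, b)
    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    grad_w = grad_out.t() @ x       # same as torch.matmul(grad_output.t(), input)
    grad_b = grad_out.sum(dim=0)    # same as grad_output.sum(dim=0)
    assert torch.allclose(w.grad, grad_w, atol=1e-4)
    assert torch.allclose(b.grad, grad_b, atol=1e-4)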


@@ -89,7 +89,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         zero_cfg: Config = None,
     ):
         # DynamicGradScaler related args
-        initial_scale = grad_scal_cfg.fp16.initial_scale
+        if gpc.config.model.dtype is torch.float32:
+            initial_scale = 1
+        else:
+            initial_scale = grad_scal_cfg.fp16.initial_scale
         min_scale = grad_scal_cfg.fp16.min_scale
         growth_interval = grad_scal_cfg.fp16.growth_interval
         growth_factor = grad_scal_cfg.growth_factor
@@ -533,7 +536,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             norm_groups.append(norm_group)

         loss_scale = float(self.loss_scale.item())  # backup
-        self.grad_scaler.update(found_inf)
+        if not gpc.config.model.dtype is torch.float32:
+            self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -576,8 +580,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm = sum(norm_groups) ** 0.5

         # the following operations are performed only on the rank to which parameters are assigned.
-        if len(single_grad_partition_groups) != 0:
-            self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+        if not gpc.config.model.dtype is torch.float32:
+            if len(single_grad_partition_groups) != 0:
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)

         timer("cal_norm").stop()
         # update the parameters
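In short, fp32 training pins the loss scale to 1 and skips both the scaler update and the unscale-and-clip step, since loss scaling only exists to keep small low-precision gradients from underflowing. A tiny illustrative helper, hypothetical and not part of the repo's API, summarizing when the changes above engage the DynamicGradScaler:

    import torch

    # hypothetical helper: grad scaling stays active for fp16/bf16, is bypassed for fp32
    def grad_scaling_enabled(model_dtype: torch.dtype) -> bool:
        return model_dtype is not torch.float32

    assert grad_scaling_enabled(torch.float16)
    assert grad_scaling_enabled(torch.bfloat16)
    assert not grad_scaling_enabled(torch.float32)   # fp32 (and tf32) run unscaled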