feat(*): support fp32 training (#155)

* support float32 training

* fix lint

* add adaptation in model/utils.py

* remove some unnecessary code

* fix lint

* feat(optim): add support for fp32 zero

* Revert "Merge pull request #2 from SolenoidWGT/fp32_zero"

This reverts commit 53fc50b0e5, reversing
changes made to 40f24d0a73.

revert commit

* merge develop

* Update utils.py

* support fp32 in zero optimizer

* modify the dtype

---------

Co-authored-by: wangguoteng.p <wangguoteng925@qq.com>
pull/161/head
ytxiong 2023-08-04 16:05:30 +08:00 committed by GitHub
parent 0268d8eda1
commit 853becfb6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 142 additions and 19 deletions

View File

@ -162,8 +162,22 @@ def args_sanity_check():
gpc.config.model.dtype = torch.bfloat16
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
gpc.config.model.dtype = torch.float16
elif gpc.config.model.dtype == "torch.float32":
assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
gpc.config.model.dtype = torch.float32
elif gpc.config.model.dtype == "torch.tf32":
assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
gpc.config.model.dtype = torch.float32
else:
assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16"]
assert gpc.config.model.dtype in [
"torch.float16",
"torch.half",
"torch.bfloat16",
"torch.float32",
"torch.tf32",
]
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201

View File

@ -5,15 +5,13 @@ from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import (
ColumnParallelLinear,
RowParallelLinear,
fused_dense_func,
)
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch
class ScaleColumnParallelLinear(nn.Linear):
@ -61,7 +59,7 @@ class ScaleColumnParallelLinear(nn.Linear):
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fused_dense_func(
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
)
@ -107,11 +105,33 @@ class RewardModelLinear(ScaleColumnParallelLinear):
weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
else:
weight = self.weight
return fused_dense_func(
return fused_dense_func_torch(
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
)
class ColumnParallelLinearTorch(ColumnParallelLinear):
def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func_torch(
x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
)
class RowParallelLinearTorch(RowParallelLinear):
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = fused_dense_func_torch(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class FeedForward(nn.Module):
"""
FeedForward.
@ -143,7 +163,7 @@ class FeedForward(nn.Module):
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
self.w1 = ColumnParallelLinearTorch(
in_features,
hidden_features,
process_group,
@ -152,10 +172,10 @@ class FeedForward(nn.Module):
device=device,
dtype=dtype,
)
self.w2 = ColumnParallelLinear(
self.w2 = ColumnParallelLinearTorch(
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
)
self.w3 = RowParallelLinear(
self.w3 = RowParallelLinearTorch(
hidden_features,
out_features,
process_group,

View File

@ -12,12 +12,12 @@ from flash_attn.modules.mha import (
SelfAttention,
_update_kv_cache,
)
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
class MHA(nn.Module):
@ -78,7 +78,7 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
# notice here should change bias=True
self.Wqkv = ColumnParallelLinear(
self.Wqkv = ColumnParallelLinearTorch(
embed_dim,
3 * embed_dim,
process_group,
@ -95,7 +95,7 @@ class MHA(nn.Module):
)
# output projection always have the bias (for now)
self.out_proj = RowParallelLinear(
self.out_proj = RowParallelLinearTorch(
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
)
# need to assign tp attribute so that internlm know it is tensor parallel module

View File

@ -1,7 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDenseFunc
from flash_attn.utils.distributed import (
all_gather_raw,
all_reduce_raw,
reduce_scatter_raw,
)
from torch import Tensor
from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup
from internlm.core.context import global_context as gpc
@ -72,6 +84,78 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
assert input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), input)
grad_bias = grad_output.sum(dim=0) if has_d_bias else None
return grad_weight, grad_bias
# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
@staticmethod
@custom_bwd
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if process_group is not None and sequence_parallel:
handle_x.wait()
# we remove the cuda independence, which is different from flash_attn.
grad_weight, grad_bias = linear_bias_wgrad_torch(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def fused_dense_func_torch(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
def try_import_RMSNorm():
"""

View File

@ -89,6 +89,9 @@ class HybridZeroOptimizer(BaseOptimizer):
zero_cfg: Config = None,
):
# DynamicGradScaler related args
if gpc.config.model.dtype is torch.float32:
initial_scale = 1
else:
initial_scale = grad_scal_cfg.fp16.initial_scale
min_scale = grad_scal_cfg.fp16.min_scale
growth_interval = grad_scal_cfg.fp16.growth_interval
@ -533,6 +536,7 @@ class HybridZeroOptimizer(BaseOptimizer):
norm_groups.append(norm_group)
loss_scale = float(self.loss_scale.item()) # backup
if not gpc.config.model.dtype is torch.float32:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
@ -576,6 +580,7 @@ class HybridZeroOptimizer(BaseOptimizer):
global_norm = sum(norm_groups) ** 0.5
# the following operations are performed only on the rank to which parameters are assigned.
if not gpc.config.model.dtype is torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)