mirror of https://github.com/hpcaitech/ColossalAI
[formart] format fixed for kernel\cuda_native codes (#335)
parent 00670c870e
commit eaac03ae1d
@@ -37,10 +37,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-          = colossal_layer_norm_cuda.backward_affine(
-            grad_output.contiguous(), mean, invvar,
-            input_, ctx.normalized_shape,
-            weight_, bias_, ctx.eps)
+            = colossal_layer_norm_cuda.backward_affine(
+                grad_output.contiguous(), mean, invvar,
+                input_, ctx.normalized_shape,
+                weight_, bias_, ctx.eps)

         return grad_input, grad_weight, grad_bias, None, None

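For orientation, the hunk above only reflows the continuation lines of the colossal_layer_norm_cuda.backward_affine call inside the backward pass of the fused layer-norm autograd function. A minimal sketch of that backward method, reassembled from the context lines shown here (the forward pass is omitted, and colossal_layer_norm_cuda is the compiled CUDA extension, assumed to be built and importable):

import torch


class FusedLayerNormAffineFunction(torch.autograd.Function):
    # forward() is assumed to save input_, weight_, bias_, mean and invvar
    # via ctx.save_for_backward and to stash normalized_shape / eps on ctx.

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = colossal_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)
        # the trailing Nones correspond to the non-tensor forward arguments
        # (normalized_shape, eps), which receive no gradient
        return grad_input, grad_weight, grad_bias, None, None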
@@ -9,7 +9,7 @@ from torch.autograd import Function

 def check_config(config):
     if config.hidden_size % config.nhead != 0:
-        raise Exception(f"hidden_size % nhead != 0")
+        raise Exception("hidden_size % nhead != 0")

     factor = 8 if config.fp16 else 4
     upbound = factor * 1024 * 4
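The change above just drops an f-string prefix that has no placeholders; the message itself is unchanged. If the offending values were wanted in the error, a hypothetical variant (not part of this commit) could interpolate them:

def check_config(config):
    # hidden_size must be divisible by the number of attention heads
    if config.hidden_size % config.nhead != 0:
        raise Exception(f"hidden_size ({config.hidden_size}) is not "
                        f"divisible by nhead ({config.nhead})")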
@@ -215,15 +215,14 @@ class MultiHeadAttention(nn.Module):
             with torch.no_grad():
                 self.in_proj_weight.copy_(
-                    attn_qkvw_global.view(3, hs, hs)[:,
-                                                     int(hs * rank_in_pg /
-                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                           self.pg_size), :])
+                    attn_qkvw_global.view(3, hs, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size),
+                        :])
                 self.in_proj_bias.copy_(
-                    attn_qkvb_global.view(3, hs)[:,
-                                                 int(hs * rank_in_pg /
-                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                       self.pg_size)])
+                    attn_qkvb_global.view(3, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size)])

             attn_ow_global = torch.empty(hs, hs)
             nn.init.xavier_uniform_(attn_ow_global, 1.0)
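The two copy_ calls above keep the same behaviour; only the slice layout is reflowed. What the slice does: each tensor-parallel rank keeps an hs / pg_size-wide block of the fused QKV weight and bias. A small self-contained illustration with toy values (hs, pg_size and rank_in_pg below are placeholders, not the module's real configuration):

import torch

hs = 8          # hidden size (toy value)
pg_size = 2     # tensor-parallel group size (toy value)
rank_in_pg = 1  # this rank's index within the group (toy value)

attn_qkvw_global = torch.randn(3 * hs, hs)  # fused [Q; K; V] projection weight
attn_qkvb_global = torch.randn(3 * hs)      # fused [Q; K; V] projection bias

start = int(hs * rank_in_pg / pg_size)
end = int(hs * (rank_in_pg + 1) / pg_size)

# same indexing as the diff: view as (3, hs, hs) / (3, hs), then slice dim 1
local_qkvw = attn_qkvw_global.view(3, hs, hs)[:, start:end, :]
local_qkvb = attn_qkvb_global.view(3, hs)[:, start:end]

print(local_qkvw.shape)  # torch.Size([3, 4, 8]) -> hs // pg_size rows per Q/K/V
print(local_qkvb.shape)  # torch.Size([3, 4])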
@@ -231,10 +230,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[:,
-                                                           int(hs * rank_in_pg /
-                                                               self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                                 self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[
+                    :, int(hs * rank_in_pg / self.pg_size):
+                    int(hs * (rank_in_pg + 1) / self.pg_size)])

         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
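The out_proj slice follows the same pattern, except that attn_ow_global is a plain (hs, hs) matrix and only its columns are partitioned. Continuing the toy example above (same placeholder values; that this column block lines up with the rank's share of the attention output is an assumption about the surrounding tensor-parallel layout, not something shown in the hunk):

import torch

hs, pg_size, rank_in_pg = 8, 2, 1  # toy values, as before

attn_ow_global = torch.randn(hs, hs)  # output projection weight
start = int(hs * rank_in_pg / pg_size)
end = int(hs * (rank_in_pg + 1) / pg_size)

local_ow = attn_ow_global[:, start:end]
print(local_ow.shape)  # torch.Size([8, 4])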