[format] format fixed for kernel\cuda_native codes (#335)

pull/394/head
ExtremeViscent 2022-03-09 01:44:20 +00:00 committed by Frank Lee
parent 00670c870e
commit eaac03ae1d
2 changed files with 15 additions and 17 deletions

@@ -37,10 +37,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-          = colossal_layer_norm_cuda.backward_affine(
-            grad_output.contiguous(), mean, invvar,
-            input_, ctx.normalized_shape,
-            weight_, bias_, ctx.eps)
+            = colossal_layer_norm_cuda.backward_affine(
+                grad_output.contiguous(), mean, invvar,
+                input_, ctx.normalized_shape,
+                weight_, bias_, ctx.eps)
         return grad_input, grad_weight, grad_bias, None, None
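
The reindented lines above sit in the backward() of FusedLayerNormAffineFunction, which wraps the colossal_layer_norm_cuda extension. For reference, a minimal usage sketch follows; it is not part of this commit, and the import path, tensor shapes, and the (input, weight, bias, normalized_shape, eps) argument order are assumptions inferred from the five gradients returned above.

# --- illustrative sketch, not part of this commit ---
import torch

# Assumed import path; only the class name appears in the hunk above.
from colossalai.kernel.cuda_native.layer_norm import FusedLayerNormAffineFunction

hidden = 1024
x = torch.randn(4, 128, hidden, device="cuda", requires_grad=True)
weight = torch.ones(hidden, device="cuda", requires_grad=True)
bias = torch.zeros(hidden, device="cuda", requires_grad=True)

# forward() saves input_/weight_/bias_/mean/invvar in ctx; the reformatted
# backward_affine() call above consumes them when autograd runs.
out = FusedLayerNormAffineFunction.apply(x, weight, bias, (hidden,), 1e-5)
out.sum().backward()    # -> grad_input, grad_weight, grad_bias as returned above
# --- end sketch ---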

@@ -9,7 +9,7 @@ from torch.autograd import Function
 def check_config(config):
     if config.hidden_size % config.nhead != 0:
-        raise Exception(f"hidden_size % nhead != 0")
+        raise Exception("hidden_size % nhead != 0")
     factor = 8 if config.fp16 else 4
     upbound = factor * 1024 * 4
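
The corrected message above reports a hard constraint: hidden_size must split evenly across the attention heads, while the fp16/fp32 factor computes an upper bound (presumably on hidden_size; its use lies outside this hunk). A quick worked example with made-up numbers, not taken from the commit:

# --- worked example, values are arbitrary and not from the commit ---
hidden_size, nhead = 1024, 16
assert hidden_size % nhead == 0    # passes: each head gets 1024 // 16 = 64 dims
# hidden_size=1024, nhead=12 would fail: 1024 % 12 == 4, so check_config() raises.

factor = 8                         # config.fp16 == True
upbound = factor * 1024 * 4        # 32768; with fp32, 4 * 1024 * 4 = 16384
# --- end example ---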
@@ -215,15 +215,14 @@ class MultiHeadAttention(nn.Module):
             with torch.no_grad():
                 self.in_proj_weight.copy_(
-                    attn_qkvw_global.view(3, hs, hs)[:,
-                                                     int(hs * rank_in_pg /
-                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                           self.pg_size), :])
+                    attn_qkvw_global.view(3, hs, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size),
+                        :])
                 self.in_proj_bias.copy_(
-                    attn_qkvb_global.view(3, hs)[:,
-                                                 int(hs * rank_in_pg /
-                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                       self.pg_size)])
+                    attn_qkvb_global.view(3, hs)[
+                        :, int(hs * rank_in_pg / self.pg_size):
+                        int(hs * (rank_in_pg + 1) / self.pg_size)])
             attn_ow_global = torch.empty(hs, hs)
             nn.init.xavier_uniform_(attn_ow_global, 1.0)
@@ -231,10 +230,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[:,
-                                                           int(hs * rank_in_pg /
-                                                               self.pg_size):int(hs * (rank_in_pg + 1) /
-                                                                                 self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[
+                    :, int(hs * rank_in_pg / self.pg_size):
+                    int(hs * (rank_in_pg + 1) / self.pg_size)])
         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
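
Both reformatted copy_() calls in the two hunks above perform the same tensor-parallel column slicing: each rank in the process group keeps columns int(hs * rank_in_pg / pg_size) up to int(hs * (rank_in_pg + 1) / pg_size) of the globally initialized weights. A small sketch of that index arithmetic, using assumed values for hs and pg_size:

# --- illustrative sketch, values assumed, not part of this commit ---
hs = 1024        # hidden size
pg_size = 4      # tensor-parallel process-group size
for rank_in_pg in range(pg_size):
    start = int(hs * rank_in_pg / pg_size)
    end = int(hs * (rank_in_pg + 1) / pg_size)
    # rank 0 -> 0:256, rank 1 -> 256:512, rank 2 -> 512:768, rank 3 -> 768:1024
    print(f"rank {rank_in_pg}: columns {start}:{end}")

# in_proj:  attn_qkvw_global.view(3, hs, hs)[:, start:end, :] keeps this rank's
#           slice of each of the Q, K and V projection matrices;
# out_proj: attn_ow_global[:, start:end] keeps the matching input columns.
# --- end sketch ---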