mirror of https://github.com/hpcaitech/ColossalAI
30 lines
1005 B
Python
Executable File
30 lines
1005 B
Python
Executable File
import math
|
|
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
|
|
def glorot_uniform_af(x, gain=1.0):
    """Initialize *x* in place with Glorot (Xavier) uniform values, AlphaFold-style.

    PyTorch's xavier initializer assumes weights laid out as
    ``[feature_out, feature_in, ...]`` while Jax assumes
    ``[..., feature_in, feature_out]``. The original AlphaFold2 code applies the
    Jax-style initializer to tensors shaped ``[feature_in, n_head, ..., feature_out]``,
    and this helper reproduces that convention: the last two dimensions are taken
    as (fan_in, fan_out) and any leading dimensions are folded into both fans.

    Args:
        x: tensor to initialize in place (rank >= 2 expected).
        gain: scaling factor applied to the Glorot standard deviation.

    Returns:
        The same tensor ``x``, filled with values drawn uniformly from
        ``[-sqrt(3) * std, sqrt(3) * std]`` where
        ``std = gain * sqrt(2 / (fan_in + fan_out))``.
    """
    shape = tuple(x.shape)
    fan_in = shape[-2]
    fan_out = shape[-1]

    # Leading dims act as a "receptive field" and scale both fans,
    # mirroring torch.nn.init._calculate_fan_in_and_fan_out.
    if len(shape) > 2:
        leading = math.prod(shape[:-2])
        fan_in = fan_in * leading
        fan_out = fan_out * leading

    # Glorot std, then convert to the half-width of the uniform interval.
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * std  # uniform bound from standard deviation

    nn.init.uniform_(x, -bound, bound)

    return x