@@ -646,48 +646,49 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-        param = local_state[key]
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+            param = local_state[key]
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )
-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
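
The diff only changes control flow: the qkv fusion now runs only when `num_heads == num_key_value_heads`, i.e. when the q/k/v projection weights all share a shape and can be stacked into one tensor, and `strict` is relaxed so that the original `q_proj.weight`/`k_proj.weight`/`v_proj.weight` entries, which remain in the checkpoint but have no matching parameter in this module, are not reported as unexpected keys. The sketch below is a minimal, self-contained illustration of why the stacked, transposed layout built by `torch.stack([q_w.T, k_w.T, v_w.T], dim=0)` is equivalent to the three separate projections; it is not ColossalAI's kernel, and the tensor sizes and the use of a broadcasted `torch.matmul` are assumptions for illustration only:

```python
# Minimal sketch: a fused qkv_weight built the same way as the one loaded above.
# hidden_size and num_tokens are illustrative; real weights come from the checkpoint.
import torch

hidden_size, num_tokens = 64, 5
q_w = torch.randn(hidden_size, hidden_size)  # stand-in for q_proj.weight ([out, in])
k_w = torch.randn(hidden_size, hidden_size)  # stand-in for k_proj.weight
v_w = torch.randn(hidden_size, hidden_size)  # stand-in for v_proj.weight
x = torch.randn(num_tokens, hidden_size)     # token hidden states

# Same layout as in _load_from_state_dict: each [out, in] weight is transposed,
# so a plain right-multiplication replaces nn.Linear's x @ W.T.
qkv_weight = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)  # [3, hidden, hidden]

# One broadcasted matmul yields q, k and v together: [3, num_tokens, hidden].
fused = torch.matmul(x, qkv_weight)

# Identical to running the three separate projections.
assert torch.allclose(fused[0], x @ q_w.T, atol=1e-5)
assert torch.allclose(fused[1], x @ k_w.T, atol=1e-5)
assert torch.allclose(fused[2], x @ v_w.T, atol=1e-5)
```

Under grouped-query attention (`num_heads != num_key_value_heads`) the k/v weights are narrower than the q weight, so `torch.stack` would fail on mismatched shapes; that is presumably why the new guard skips the fusion entirely and lets the parent class's `_load_from_state_dict` handle those checkpoints.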