mirror of https://github.com/hpcaitech/ColossalAI

[model checkpoint] updated saving/loading for 3d layers (#597)

parent 93089ed708
commit 77ad24bf94
@@ -1,4 +1,5 @@
 import math
+from collections import OrderedDict
 from typing import Callable

 import torch
@@ -12,13 +13,15 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
+                                            partition_tensor_parallel_state_dict)
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn import Parameter

 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
-from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
+from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
+                         linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
 from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group

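All of the `_load_from_state_dict`/`_save_to_state_dict` methods added below lean on the same pair of helpers imported above: `partition_tensor_parallel_state_dict` scatters each full checkpoint tensor along a chosen dimension across one parallel mode's group, and `gather_tensor_parallel_state_dict` is its inverse. A minimal single-process sketch of that contract, with toy `rank`/`world_size` values standing in for the real process group (the helper name and keyword layout below mirror, but are not, the real utility):

from collections import OrderedDict

import torch


def toy_partition_state_dict(state_dict, dims, partition_states, rank, world_size):
    """Keep only this rank's shard of each tensor marked for partitioning."""
    local = OrderedDict()
    for key, tensor in state_dict.items():
        if partition_states.get(key, False):
            # split along the per-key dimension and keep this rank's chunk
            local[key] = torch.chunk(tensor, world_size, dim=dims[key])[rank].clone()
        else:
            local[key] = tensor  # replicated state is carried whole
    return local


full = OrderedDict({'weight': torch.arange(8.).reshape(2, 4)})
shard = toy_partition_state_dict(full, dims={'weight': -1}, partition_states={'weight': True},
                                 rank=1, world_size=2)
assert shard['weight'].shape == (2, 2)  # last dim split across two ranks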
@@ -61,6 +64,67 @@ class LayerNorm3D(ParallelLayer):
         init.zeros_()(self.bias)
         init.ones_()(self.weight)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
                             self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)

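LayerNorm3D's parameters are sharded only along the output mode and replicated everywhere else, so loading is a two-stage fan-out: the rank that is 0 in both the input and weight modes partitions the full tensor across the output group, then the shard is broadcast down the input mode and finally down the weight mode. A toy sketch of that order on a q**3 cube, with a plain dict standing in for the ranks (no torch.distributed; the (input, output, weight) coordinate keys are a hypothetical stand-in):

from collections import OrderedDict

import torch

q = 2
full = OrderedDict({'weight': torch.arange(8.)})
# Step 1: the (input=0, weight=0) ranks partition along the output mode.
held = {}
for o in range(q):
    held[(0, o, 0)] = OrderedDict({'weight': torch.chunk(full['weight'], q, 0)[o]})
# Step 2: weight=0 ranks broadcast along the input mode (pure replication).
for o in range(q):
    for i in range(q):
        held[(i, o, 0)] = held[(0, o, 0)]
# Step 3: every rank receives its copy along the weight mode.
for (i, o, _), state in list(held.items()):
    for w in range(q):
        held[(i, o, w)] = state
assert len(held) == q ** 3  # all 8 ranks now hold their shard/replica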
@@ -135,6 +199,122 @@ class Linear3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight.transpose(0, 1)
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: -1,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: -1,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            local_state[weight_key] = local_state[weight_key].transpose(0, 1)
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                          self.output_parallel_mode)

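Note the transposes: the checkpoint stores the Linear3D weight as (out_features, in_features), while the layer keeps it transposed locally and shards it three ways, dim 0 across the output group and dim -1 across the input and weight groups. A single-process simulation of exactly that chunk/cat order, with a toy depth q=2 and a 4x8 weight:

import torch

q = 2
w = torch.arange(32.).reshape(4, 8)  # checkpoint layout: (out, in)
w_t = w.transpose(0, 1)              # stored transposed on device: (in, out)
out_shards = torch.chunk(w_t, q, dim=0)                            # output group: dim 0
grid = [torch.chunk(s, q, dim=-1) for s in out_shards]             # input group: dim -1
grid = [[torch.chunk(s, q, dim=-1) for s in row] for row in grid]  # weight group: dim -1

# Gathering reverses the order: weight group, then input, then output.
rows = [[torch.cat(cell, dim=-1) for cell in row] for row in grid]
cols = [torch.cat(row, dim=-1) for row in rows]
restored = torch.cat(cols, dim=0).transpose(0, 1)
assert torch.equal(restored, w)  # layout round-trips exactly

Gathering has to undo the partitions in reverse, which is why _save_to_state_dict starts with the weight group even though _load_from_state_dict ends with it.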
@@ -212,6 +392,73 @@ class Classifier3D(ParallelLayer):
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
             broadcast(self.bias, input_src_rank, self.input_parallel_mode)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                              self.output_parallel_mode)

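Classifier3D only ever splits the weight (dim -1, the feature dimension, across the output group); `partition_states` flags the bias as replicated, so it rides along unsplit while the remaining modes simply broadcast. Roughly, for one toy rank:

from collections import OrderedDict

import torch

q, rank = 2, 0
sd = OrderedDict({'weight': torch.randn(10, 8), 'bias': torch.randn(10)})
dims = {'weight': -1, 'bias': 0}
partition = {'weight': True, 'bias': False}
local = OrderedDict(
    (k, torch.chunk(v, q, dim=dims[k])[rank] if partition[k] else v)
    for k, v in sd.items())
assert local['weight'].shape == (10, 4)  # features split q ways
assert local['bias'].shape == (10,)      # bias kept whole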
@@ -296,6 +543,122 @@ class VocabParallelClassifier3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
                          self.weight_parallel_mode, self.output_parallel_mode)

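Unlike Linear3D, VocabParallelClassifier3D keeps its local weight in (num_classes, in_features) layout and transposes on the fly in forward, so no transpose is needed when saving. The equivalence being relied on, in plain PyTorch on toy shapes:

import torch

x = torch.randn(4, 6)
w = torch.randn(10, 6)  # classifier layout: (classes, features)
b = torch.zeros(10)
y_transposed = x @ w.transpose(0, 1) + b          # what the transposed matmul computes
y_functional = torch.nn.functional.linear(x, w, b)  # reference classifier projection
assert torch.allclose(y_transposed, y_functional)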
@@ -392,12 +755,98 @@ class PatchEmbedding3D(ParallelLayer):
         self.cls_token.register_hook(self._sync_grad_hook)
         self.pos_embed.register_hook(self._sync_grad_hook)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        cls_token_key = prefix + 'cls_token'
+        pos_embed_key = prefix + 'pos_embed'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+            # cls token
+            cls_token = state_dict.pop(cls_token_key, None)
+            if cls_token is not None:
+                local_state[cls_token_key] = cls_token
+            # pos embed
+            pos_embed = state_dict.pop(pos_embed_key, None)
+            if pos_embed is not None:
+                local_state[pos_embed_key] = pos_embed
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0,
+                    cls_token_key: -1,
+                    pos_embed_key: -1
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                    cls_token_key: True,
+                    pos_embed_key: True
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        cls_token_key = prefix + 'cls_token'
+        pos_embed_key = prefix + 'pos_embed'
+        local_state = OrderedDict({
+            weight_key: self.weight,
+            bias_key: self.bias,
+            cls_token_key: self.cls_token,
+            pos_embed_key: self.pos_embed
+        })
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0,
+                    cls_token_key: -1,
+                    pos_embed_key: -1
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                    cls_token_key: True,
+                    pos_embed_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
             output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)

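PatchEmbedding3D (like the embeddings below) splits the incoming batch along dim 0 twice, once per parallel mode, so each rank convolves only B / q**2 samples. A toy shape check, assuming depth q=2 and rank 0 in both groups:

import torch

q, B = 2, 8
x = torch.randn(B, 3, 32, 32)
x = torch.chunk(x, q, dim=0)[0]  # split in the weight-mode group (toy rank 0)
x = torch.chunk(x, q, dim=0)[0]  # split again in the input-mode group
assert x.shape[0] == B // q ** 2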
@@ -480,6 +929,49 @@ class Embedding3D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)

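The reset_parameters context at the top of this hunk zeroes the padding row in place; under `torch.no_grad()` the `fill_` is invisible to autograd, which is the standard way to pin an embedding row. A self-contained toy with hypothetical sizes:

import torch

weight = torch.nn.Parameter(torch.randn(10, 8))  # (vocab, embed), toy sizes
padding_idx = 2
with torch.no_grad():
    weight[padding_idx].fill_(0)  # in-place, not tracked by autograd
assert weight[padding_idx].abs().sum() == 0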
@@ -570,6 +1062,76 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={weight_key: 0},
+            partition_states={weight_key: True},
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={weight_key: 0},
+            partition_states={weight_key: True},
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
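VocabParallelEmbedding3D shards the vocabulary dimension in both the input and weight modes (dim 0 above) and the embedding dimension in the output mode (dim -1), so a global `padding_idx` lands on only the ranks whose vocab slice contains it; that is why `reset_parameters` indexes `self.padding_idx - self.vocab_start_index`. A toy check of the offset arithmetic, assuming depth q=2 and a vocab of 16:

vocab, q = 16, 2
per_rank = vocab // q ** 2  # vocab dim is split across two modes
padding_idx = 5
owners = []
for rank in range(q * q):
    start = rank * per_rank  # toy stand-in for vocab_start_index
    if start <= padding_idx < start + per_rank:
        owners.append(rank)
        local_row = padding_idx - start
        assert 0 <= local_row < per_rank
assert owners == [1]  # exactly one vocab slice owns the padding row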