mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] updated saving/loading for 3d layers (#597)
parent
93089ed708
commit
77ad24bf94
|
@ -1,4 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -12,13 +13,15 @@ from colossalai.global_variables import tensor_parallel_env as env
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.nn.layer.base_layer import ParallelLayer
|
from colossalai.nn.layer.base_layer import ParallelLayer
|
||||||
from colossalai.registry import LAYERS
|
from colossalai.registry import LAYERS
|
||||||
|
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
|
||||||
|
partition_tensor_parallel_state_dict)
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
|
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
|
||||||
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
|
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
|
||||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,6 +64,67 @@ class LayerNorm3D(ParallelLayer):
|
||||||
init.zeros_()(self.bias)
|
init.zeros_()(self.bias)
|
||||||
init.ones_()(self.weight)
|
init.ones_()(self.weight)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight.transpose(0, 1)
|
||||||
|
# bias
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# broadcast in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
|
||||||
|
# broadcast in weight groups
|
||||||
|
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
|
||||||
|
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
|
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
|
||||||
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
|
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
|
||||||
|
@ -135,6 +199,122 @@ class Linear3D(ParallelLayer):
|
||||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight.transpose(0, 1)
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# partition in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# partition in weight groups
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
|
||||||
|
# gather in weight groups
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
||||||
self.output_parallel_mode)
|
self.output_parallel_mode)
|
||||||
|
@ -212,6 +392,73 @@ class Classifier3D(ParallelLayer):
|
||||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
if self.has_weight:
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# broadcast in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
|
||||||
|
# broadcast in weight groups
|
||||||
|
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict()
|
||||||
|
if self.has_weight:
|
||||||
|
local_state[weight_key] = self.weight
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
||||||
self.output_parallel_mode)
|
self.output_parallel_mode)
|
||||||
|
@ -296,6 +543,122 @@ class VocabParallelClassifier3D(ParallelLayer):
|
||||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
if self.has_weight:
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
if self.bias is not None:
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# partition in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# partition in weight groups
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
if self.bias is not None:
|
||||||
|
local_state[bias_key] = self.bias
|
||||||
|
|
||||||
|
# gather in weight groups
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: -1,
|
||||||
|
bias_key: 0
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: False
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
|
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
|
||||||
self.weight_parallel_mode, self.output_parallel_mode)
|
self.weight_parallel_mode, self.output_parallel_mode)
|
||||||
|
@ -392,12 +755,98 @@ class PatchEmbedding3D(ParallelLayer):
|
||||||
self.cls_token.register_hook(self._sync_grad_hook)
|
self.cls_token.register_hook(self._sync_grad_hook)
|
||||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
cls_token_key = prefix + 'cls_token'
|
||||||
|
pos_embed_key = prefix + 'pos_embed'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
# bias
|
||||||
|
bias = state_dict.pop(bias_key, None)
|
||||||
|
if bias is not None:
|
||||||
|
local_state[bias_key] = bias
|
||||||
|
# cls token
|
||||||
|
cls_token = state_dict.pop(cls_token_key, None)
|
||||||
|
if cls_token is not None:
|
||||||
|
local_state[cls_token_key] = cls_token
|
||||||
|
# pos embed
|
||||||
|
pos_embed = state_dict.pop(pos_embed_key, None)
|
||||||
|
if pos_embed is not None:
|
||||||
|
local_state[pos_embed_key] = pos_embed
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0,
|
||||||
|
cls_token_key: -1,
|
||||||
|
pos_embed_key: -1
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True,
|
||||||
|
cls_token_key: True,
|
||||||
|
pos_embed_key: True
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# broadcast in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
|
||||||
|
# broadcast in weight groups
|
||||||
|
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
bias_key = prefix + 'bias'
|
||||||
|
cls_token_key = prefix + 'cls_token'
|
||||||
|
pos_embed_key = prefix + 'pos_embed'
|
||||||
|
local_state = OrderedDict({
|
||||||
|
weight_key: self.weight,
|
||||||
|
bias_key: self.bias,
|
||||||
|
cls_token_key: self.cls_token,
|
||||||
|
pos_embed_key: self.pos_embed
|
||||||
|
})
|
||||||
|
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={
|
||||||
|
weight_key: 0,
|
||||||
|
bias_key: 0,
|
||||||
|
cls_token_key: -1,
|
||||||
|
pos_embed_key: -1
|
||||||
|
},
|
||||||
|
partition_states={
|
||||||
|
weight_key: True,
|
||||||
|
bias_key: True,
|
||||||
|
cls_token_key: True,
|
||||||
|
pos_embed_key: True
|
||||||
|
},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
||||||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||||
if self.flatten:
|
if self.flatten:
|
||||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||||
|
|
||||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||||
output = torch.cat((cls_token, output), dim=1)
|
output = torch.cat((cls_token, output), dim=1)
|
||||||
|
@ -480,6 +929,49 @@ class Embedding3D(ParallelLayer):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
)
|
||||||
|
# broadcast in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
|
||||||
|
# broadcast in weight groups
|
||||||
|
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
||||||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||||
|
@ -570,6 +1062,76 @@ class VocabParallelEmbedding3D(torch.nn.Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||||
|
local_state = OrderedDict()
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
# weight
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is not None:
|
||||||
|
local_state[weight_key] = weight
|
||||||
|
|
||||||
|
# partition in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={weight_key: -1},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
)
|
||||||
|
# partition in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
)
|
||||||
|
# partition in weight groups
|
||||||
|
local_state = partition_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
weight_key = prefix + 'weight'
|
||||||
|
local_state = OrderedDict({weight_key: self.weight})
|
||||||
|
|
||||||
|
# gather in weight groups
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.weight_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in input groups
|
||||||
|
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.input_parallel_mode,
|
||||||
|
dims={weight_key: 0},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
# gather in output groups
|
||||||
|
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
|
||||||
|
gpc.get_local_rank(self.weight_parallel_mode) == 0:
|
||||||
|
local_state = gather_tensor_parallel_state_dict(
|
||||||
|
local_state,
|
||||||
|
self.output_parallel_mode,
|
||||||
|
dims={weight_key: -1},
|
||||||
|
partition_states={weight_key: True},
|
||||||
|
keep_vars=keep_vars,
|
||||||
|
)
|
||||||
|
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||||
|
destination.update(local_state)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue