[model checkpoint] updated saving/loading for 3d layers (#597)

pull/625/head
アマデウス 2022-04-01 16:52:47 +08:00 committed by GitHub
parent 93089ed708
commit 77ad24bf94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 565 additions and 3 deletions

View File

@ -1,4 +1,5 @@
import math
from collections import OrderedDict
from typing import Callable
import torch
@ -12,13 +13,15 @@ from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@ -61,6 +64,67 @@ class LayerNorm3D(ParallelLayer):
init.zeros_()(self.bias)
init.ones_()(self.weight)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight.transpose(0, 1)
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True,
},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
@ -135,6 +199,122 @@ class Linear3D(ParallelLayer):
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight.transpose(0, 1)
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in weight groups
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
# gather in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
@ -212,6 +392,73 @@ class Classifier3D(ParallelLayer):
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
@ -296,6 +543,122 @@ class VocabParallelClassifier3D(ParallelLayer):
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in weight groups
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
# gather in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
@ -392,12 +755,98 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# cls token
cls_token = state_dict.pop(cls_token_key, None)
if cls_token is not None:
local_state[cls_token_key] = cls_token
# pos embed
pos_embed = state_dict.pop(pos_embed_key, None)
if pos_embed is not None:
local_state[pos_embed_key] = pos_embed
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@ -480,6 +929,49 @@ class Embedding3D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
@ -570,6 +1062,76 @@ class VocabParallelEmbedding3D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={weight_key: -1},
partition_states={weight_key: True},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
# gather in weight groups
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
# gather in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)