[tensor] fixed non-serializable colo parameter during model checkpointing (#1153)

Branch: pull/1164/head
Author: Frank Lee, 2022-06-22 11:43:38 +08:00, committed by GitHub
Parent: ffa025e120
Commit: f8eec98ff5
1 changed file with 39 additions and 3 deletions
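Why this fix is needed: the spec attached to a distributed ColoParameter holds a reference to a process group, and process-group handles cannot be pickled, so `torch.save(model.state_dict())` fails for such models. The sketch below is illustrative only and is not part of the commit; it uses a `threading.Lock` as a stand-in for the unpicklable process group to show the swap-serialize-restore idea that the patched state_dict applies.

    import pickle
    import threading


    class FakeSpec:
        """Stand-in for a tensor spec that carries a live, unpicklable handle."""

        def __init__(self):
            # threading.Lock() mimics an unpicklable object such as a process
            # group; it is NOT the real ColossalAI TensorSpec.
            self.handle = threading.Lock()


    spec = FakeSpec()
    try:
        pickle.dumps(spec)
    except TypeError as err:
        print(f"serialization fails: {err}")  # cannot pickle '_thread.lock' object

    # The commit applies the same idea: swap the offending attribute for a
    # serializable placeholder (a replicated spec), serialize, then restore it.
    saved = spec.handle
    spec.handle = None            # analogous to set_spec(TensorSpec(distspec.replicate()))
    payload = pickle.dumps(spec)  # now succeeds
    spec.handle = saved           # recover, as colo_state_dict does after state_dict_func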

@@ -1,13 +1,13 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, TensorSpec
 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding
+from copy import copy
 from torch import nn
 from typing import Iterator, Tuple, Union
+from functools import partialmethod

 # find named_params includes replica
@@ -34,6 +34,38 @@ def ColoModulize(module):
     module._colo_visited = True


+def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
+    # build a param-to-spec mapping
+    mapping = dict()
+
+    # gather all params
+    has_dist_parameter = False
+    with torch.no_grad():
+        for param in self.parameters():
+            if isinstance(param, ColoParameter) and param.has_spec():
+                has_dist_parameter = True
+                mapping[id(param)] = copy(param.spec)
+                param.set_spec(TensorSpec(distspec.replicate()))
+
+    # TODO: fix when keep_vars = True
+    # when keep_vars = False, the state_dict_func will call detach to create
+    # new tensors, but when keep_vars = True, the recovery of the spec will be
+    # reflected in `ret`, such that the final state dict will still contain a
+    # process group, raising an exception as it is not serializable
+    assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
+
+    ret = state_dict_func(self, destination, prefix, keep_vars)
+
+    # recover the original specs
+    with torch.no_grad():
+        for param in self.parameters():
+            param_id = id(param)
+            if param_id in mapping:
+                spec = mapping[param_id]
+                param.set_spec(spec)
+    return ret
+
+
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -52,6 +84,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         register_colo_module(torch.nn.Linear, ColoLinear())
         register_colo_module(torch.nn.Embedding, ColoEmbedding())

+    def _pre_context_exec(self):
+        self.state_dict_func = nn.Module.state_dict
+        nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
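A rough usage sketch of the patched checkpointing path, not part of the diff: the import path of ColoInitContext is an assumption based on the file being modified, so adjust it to the actual package layout.

    # Illustrative only; the ColoInitContext import path is assumed, not confirmed by this diff.
    import torch
    import torch.nn as nn
    from colossalai.utils.model.colo_init_context import ColoInitContext  # assumed module path

    with ColoInitContext(device=torch.device('cpu')):
        # Parameters created here become ColoParameters, and _pre_context_exec
        # swaps nn.Module.state_dict for the colo_state_dict defined in this commit.
        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

    # Even if some parameters carry a distributed spec, the patched state_dict
    # temporarily replaces it with a replicated spec, so the returned tensors no
    # longer reference a process group and torch.save works as usual.
    state = model.state_dict()  # keep_vars must remain False for distributed params
    torch.save(state, 'checkpoint.pt')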