fix import error in sharded model v2 (#1053)

pull/1063/head
ver217 2022-06-02 13:48:22 +08:00 committed by GitHub
parent e1922ea4f6
commit e3fde4ee6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -2,7 +2,6 @@ import functools
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional, Iterator, Tuple from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy from copy import deepcopy
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
import itertools import itertools
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor) get_gradient_predivide_factor)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class ShardedModelV2(nn.Module): class ShardedModelV2(nn.Module):
""" """