mirror of https://github.com/hpcaitech/ColossalAI
fix import error in sharded model v2 (#1053)
parent
e1922ea4f6
commit
e3fde4ee6b
|
@ -2,7 +2,6 @@ import functools
|
|||
from collections import OrderedDict
|
||||
from typing import Any, Optional, Iterator, Tuple
|
||||
from copy import deepcopy
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
import itertools
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto
|
|||
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor)
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue