@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List
 
@@ -39,6 +40,9 @@ class OPTPolicy(Policy):
         from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
 
         policy = {}
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
|
|
|
|
|
|
|
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
|
|
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
|
|
|
|
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
|
|
|
|
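
The guard added in the second hunk follows a warn-and-downgrade pattern: rather than raising on an unsupported feature, module_policy flips the flag off on the shared shard_config and emits a warning, so the rest of the sharding pipeline proceeds with sequence parallelism disabled. Below is a minimal, self-contained sketch of that behavior; the ShardConfig dataclass and the disable_unsupported_sequence_parallelism helper are hypothetical stand-ins for illustration, not ColossalAI's real classes.

import warnings
from dataclasses import dataclass

# Hypothetical stand-in for the shard config, reduced to the
# two flags this diff touches.
@dataclass
class ShardConfig:
    enable_tensor_parallelism: bool = True
    enable_sequence_parallelism: bool = False

def disable_unsupported_sequence_parallelism(shard_config: ShardConfig) -> None:
    # Mirrors the guard added in the diff: downgrade the unsupported
    # flag in place and tell the user why, instead of raising.
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn(
            "OPT doesn't support sequence parallelism now, "
            "will ignore the sequence parallelism flag."
        )

config = ShardConfig(enable_sequence_parallelism=True)
disable_unsupported_sequence_parallelism(config)
assert config.enable_sequence_parallelism is False  # flag was downgraded

Because the flag is mutated on the config object itself, any later check of enable_sequence_parallelism in the same sharding run sees it as disabled, which is why the guard is placed before the tensor parallelism branch.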