diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1ffbae73e..65051a61f 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -18,7 +18,7 @@ from colossalai.logging import get_dist_logger from .utils import RingComm, get_half_index, split_varlen_zigzag -MEMORY_BOUND = 1 * 1e9 +MEMORY_BOUND = 10 * 1e9 __all__ = [ "AttnMaskType",