[example] update examples related to zero/gemini (#3431)

* [zero] update legacy import

* [zero] update examples

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix import
pull/3445/head
ver217 2023-04-04 17:32:51 +08:00 committed by GitHub
parent 773955abfa
commit 573af84184
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 50 additions and 6 deletions

View File

@ -6,6 +6,7 @@ import torch.nn as nn
from colossalai.logging import get_dist_logger
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2
@ -40,5 +41,5 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
__all__ = [
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
'no_shard_zero_decrator'
'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy'
]

View File

@ -1,4 +1,9 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
clip_grad_norm = 1.0

View File

@ -1,6 +1,11 @@
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import FusedAdam
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
# fp16 = dict(
# mode=AMP_TYPE.TORCH,
# )

View File

@ -1,4 +1,8 @@
try:
from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
# colossalai > 0.2.8
from colossalai.zero.legacy import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",

View File

@ -4,3 +4,4 @@ datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
transformers

View File

@ -413,7 +413,11 @@ def main():
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
try:
from colossalai.nn.parallel import GeminiDDP
except ImportError:
# this works for unreleased main branch, and this may be released on 0.2.9
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager

View File

@ -0,0 +1,21 @@
#!/bin/bash
set -xue
pip install -r requirements.txt
BS=8
MEMCAP=0
GPUNUM=2
MODLE="facebook/opt-125m"
torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
-s \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODLE} \
--per_device_train_batch_size ${BS} \
--num_train_epochs 1

View File

@ -0,0 +1,3 @@
#!/bin/bash
cd opt && bash test_ci.sh