
[example] update examples related to zero/gemini (#3431)

* [zero] update legacy import

* [zero] update examples

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix opt tutorial

* [example] fix import
ver217 committed 573af84184 via GitHub (branch: pull/3445/head)
Changed files (lines changed in parentheses):

  1. colossalai/zero/legacy/__init__.py (3)
  2. examples/language/roberta/configs/colossalai_ddp.py (7)
  3. examples/language/roberta/configs/colossalai_zero.py (9)
  4. examples/tutorial/opt/opt/colossalai_zero.py (6)
  5. examples/tutorial/opt/opt/requirements.txt (1)
  6. examples/tutorial/opt/opt/run_clm.py (6)
  7. examples/tutorial/opt/opt/test_ci.sh (21)
  8. examples/tutorial/opt/test_ci.sh (3)
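
The recurring fix across the example configs below is a guarded import: try the pre-0.2.8 location first, then fall back to the legacy namespace that newer releases provide. A minimal sketch of the pattern, using TensorShardStrategy as in the diffs that follow:

try:
    # colossalai <= 0.2.8
    from colossalai.zero.shard_utils import TensorShardStrategy
except ImportError:
    # colossalai > 0.2.8: the old ZeRO API lives under colossalai.zero.legacy
    from colossalai.zero.legacy import TensorShardStrategy

shard_strategy = TensorShardStrategy()  # instantiated inside the zero config dicts below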

colossalai/zero/legacy/__init__.py (3)

@@ -6,6 +6,7 @@ import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
+from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy
 from .sharded_model import ShardedModelV2
 from .sharded_optim import ShardedOptimizerV2
@@ -40,5 +41,5 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
 __all__ = [
     'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
-    'no_shard_zero_decrator'
+    'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy'
 ]
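
With the shard strategies re-exported from colossalai/zero/legacy/__init__.py, the fallback branch in the example configs can import them directly from the legacy package. A small illustrative config, assuming colossalai > 0.2.8; the "auto" placement value mirrors the opt config further down, the rest is illustrative:

from colossalai.zero.legacy import TensorShardStrategy, BucketTensorShardStrategy

zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
                              tensor_placement_policy="auto"))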

examples/language/roberta/configs/colossalai_ddp.py (7)

@@ -1,4 +1,9 @@
-from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.nn.optimizer import FusedAdam
+try:
+    from colossalai.zero.shard_utils import TensorShardStrategy
+except ImportError:
+    # colossalai > 0.2.8
+    from colossalai.zero.legacy import TensorShardStrategy
 clip_grad_norm = 1.0

examples/language/roberta/configs/colossalai_zero.py (9)

@@ -1,6 +1,11 @@
-from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.nn.optimizer import FusedAdam
+try:
+    from colossalai.zero.shard_utils import TensorShardStrategy
+except ImportError:
+    # colossalai > 0.2.8
+    from colossalai.zero.legacy import TensorShardStrategy
 # fp16 = dict(
 #     mode=AMP_TYPE.TORCH,
 # )
@@ -29,4 +34,4 @@ optimizer = dict(
     weight_decay=1e-2,
 )
-# 64433
+# 64433

examples/tutorial/opt/opt/colossalai_zero.py (6)

@@ -1,4 +1,8 @@
-from colossalai.zero.shard_utils import TensorShardStrategy
+try:
+    from colossalai.zero.shard_utils import TensorShardStrategy
+except ImportError:
+    # colossalai > 0.2.8
+    from colossalai.zero.legacy import TensorShardStrategy
 zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
                               tensor_placement_policy="auto",

examples/tutorial/opt/opt/requirements.txt (1)

@@ -4,3 +4,4 @@ datasets >= 1.8.0
 sentencepiece != 0.1.92
 protobuf
 accelerate == 0.13.2
+transformers

examples/tutorial/opt/opt/run_clm.py (6)

@@ -413,7 +413,11 @@ def main():
     cai_version = colossalai.__version__
     logger.info(f'using Colossal-AI version {cai_version}')
     if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.nn.parallel import GeminiDDP
+        try:
+            from colossalai.nn.parallel import GeminiDDP
+        except ImportError:
+            # this works for unreleased main branch, and this may be released on 0.2.9
+            from colossalai.zero import GeminiDDP
         model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
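
Stitched together without the diff markers, the new code path reads roughly as below; model and PLACEMENT_POLICY are defined earlier in run_clm.py and only referenced here:

import colossalai
from packaging import version
from colossalai.utils import get_current_device

if version.parse(colossalai.__version__) > version.parse("0.1.10"):
    try:
        # colossalai <= 0.2.8
        from colossalai.nn.parallel import GeminiDDP
    except ImportError:
        # unreleased main branch, expected around 0.2.9
        from colossalai.zero import GeminiDDP
    # model and PLACEMENT_POLICY come from the surrounding script
    model = GeminiDDP(model, device=get_current_device(),
                      placement_policy=PLACEMENT_POLICY, pin_memory=True)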

examples/tutorial/opt/opt/test_ci.sh (21, new file)

@@ -0,0 +1,21 @@
+#!/bin/bash
+set -xue
+pip install -r requirements.txt
+BS=8
+MEMCAP=0
+GPUNUM=2
+MODLE="facebook/opt-125m"
+torchrun \
+  --nproc_per_node ${GPUNUM} \
+  --master_port 19198 \
+  run_clm.py \
+  -s \
+  --output_dir $PWD \
+  --mem_cap ${MEMCAP} \
+  --model_name_or_path ${MODLE} \
+  --per_device_train_batch_size ${BS} \
+  --num_train_epochs 1

examples/tutorial/opt/test_ci.sh (3, new file)

@@ -0,0 +1,3 @@
+#!/bin/bash
+cd opt && bash test_ci.sh