Browse Source

[misc] refactor launch API and tensor constructor (#5666)

* [misc] remove config arg from initialize

* [misc] remove old tensor contrusctor

* [plugin] add npu support for ddp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [devops] fix doc test ci

* [test] fix test launch

* [doc] update launch doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5541/head
Hongxin Liu 7 months ago committed by GitHub
parent
commit
7f8b16635b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      .github/workflows/doc_test_on_pr.yml
  2. 2
      applications/Colossal-LLaMA/train.py
  3. 2
      applications/ColossalChat/benchmarks/benchmark_ppo.py
  4. 2
      applications/ColossalChat/examples/training_scripts/train_dpo.py
  5. 2
      applications/ColossalChat/examples/training_scripts/train_ppo.py
  6. 2
      applications/ColossalChat/examples/training_scripts/train_rm.py
  7. 2
      applications/ColossalChat/examples/training_scripts/train_sft.py
  8. 2
      applications/ColossalEval/examples/dataset_evaluation/inference.py
  9. 2
      applications/ColossalEval/examples/gpt_evaluation/inference.py
  10. 8
      applications/ColossalMoE/infer.py
  11. 2
      applications/ColossalMoE/tests/test_mixtral_layer.py
  12. 2
      applications/ColossalMoE/tests/test_moe_checkpoint.py
  13. 8
      applications/ColossalMoE/train.py
  14. 2
      colossalai/auto_parallel/offload/amp_optimizer.py
  15. 4
      colossalai/auto_parallel/offload/base_offload_module.py
  16. 5
      colossalai/booster/plugin/torch_ddp_plugin.py
  17. 2
      colossalai/inference/README.md
  18. 16
      colossalai/initialize.py
  19. 2
      colossalai/legacy/inference/dynamic_batching/ray_dist_init.py
  20. 2
      colossalai/legacy/inference/hybridengine/engine.py
  21. 34
      colossalai/legacy/inference/pipeline/README.md
  22. 2
      colossalai/legacy/inference/pipeline/benchmark/benchmark.py
  23. 2
      colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py
  24. 2
      colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py
  25. 2
      colossalai/legacy/pipeline/rpc/utils.py
  26. 4
      colossalai/nn/optimizer/fused_adam.py
  27. 4
      colossalai/nn/optimizer/hybrid_adam.py
  28. 2
      colossalai/shardformer/README.md
  29. 2
      colossalai/shardformer/examples/convergence_benchmark.py
  30. 3
      colossalai/shardformer/examples/performance_benchmark.py
  31. 2
      colossalai/shardformer/shard/shardformer.py
  32. 2
      colossalai/tensor/d_tensor/README.md
  33. 2
      docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
  34. 2
      docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
  35. 2
      docs/source/en/basics/booster_api.md
  36. 18
      docs/source/en/basics/launch_colossalai.md
  37. 2
      docs/source/en/features/gradient_accumulation_with_booster.md
  38. 2
      docs/source/en/features/gradient_clipping_with_booster.md
  39. 2
      docs/source/en/features/lazy_init.md
  40. 10
      docs/source/en/features/mixed_precision_training_with_booster.md
  41. 2
      docs/source/en/features/nvme_offload.md
  42. 2
      docs/source/en/features/zero_with_chunk.md
  43. 2
      docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
  44. 2
      docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
  45. 2
      docs/source/zh-Hans/basics/booster_api.md
  46. 18
      docs/source/zh-Hans/basics/launch_colossalai.md
  47. 2
      docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
  48. 2
      docs/source/zh-Hans/features/gradient_clipping_with_booster.md
  49. 2
      docs/source/zh-Hans/features/lazy_init.md
  50. 12
      docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
  51. 2
      docs/source/zh-Hans/features/nvme_offload.md
  52. 2
      docs/source/zh-Hans/features/zero_with_chunk.md
  53. 4
      examples/community/roberta/pretraining/run_pretraining.py
  54. 2
      examples/images/dreambooth/debug.py
  55. 4
      examples/images/dreambooth/train_dreambooth_colossalai.py
  56. 4
      examples/images/dreambooth/train_dreambooth_colossalai_lora.py
  57. 2
      examples/images/resnet/train.py
  58. 2
      examples/images/vit/vit_benchmark.py
  59. 2
      examples/images/vit/vit_train_demo.py
  60. 2
      examples/inference/benchmark_llama.py
  61. 2
      examples/inference/run_llama_inference.py
  62. 2
      examples/language/bert/benchmark.py
  63. 2
      examples/language/bert/finetune.py
  64. 3
      examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
  65. 2
      examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
  66. 2
      examples/language/gpt/gemini/train_gpt_demo.py
  67. 2
      examples/language/gpt/hybridparallelism/benchmark.py
  68. 2
      examples/language/gpt/hybridparallelism/finetune.py
  69. 4
      examples/language/gpt/titans/train_gpt.py
  70. 2
      examples/language/grok-1/inference_tp.py
  71. 2
      examples/language/llama/benchmark.py
  72. 2
      examples/language/openmoe/benchmark/benchmark_cai.py
  73. 2
      examples/language/openmoe/train.py
  74. 2
      examples/language/opt/opt_benchmark.py
  75. 2
      examples/language/opt/opt_train_demo.py
  76. 2
      examples/language/palm/train.py
  77. 2
      examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py
  78. 2
      examples/tutorial/auto_parallel/auto_ckpt_solver_test.py
  79. 2
      examples/tutorial/new_api/cifar_resnet/train.py
  80. 2
      examples/tutorial/new_api/cifar_vit/train.py
  81. 2
      examples/tutorial/new_api/glue_bert/finetune.py
  82. 2
      examples/tutorial/opt/opt/run_clm.py
  83. 2
      tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
  84. 4
      tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
  85. 3
      tests/test_auto_parallel/test_offload/test_perf.py
  86. 4
      tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
  87. 2
      tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
  88. 2
      tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
  89. 2
      tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
  90. 2
      tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
  91. 2
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py
  92. 4
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
  93. 4
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
  94. 2
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
  95. 4
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py
  96. 4
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
  97. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
  98. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
  99. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
  100. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
  101. Some files were not shown because too many files have changed in this diff Show More

2
.github/workflows/doc_test_on_pr.yml

@ -56,7 +56,7 @@ jobs:
needs: detect-changed-doc
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm
timeout-minutes: 20
defaults:

2
applications/Colossal-LLaMA/train.py

@ -136,7 +136,7 @@ def main() -> None:
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
accelerator = get_accelerator()
coordinator = DistCoordinator()

2
applications/ColossalChat/benchmarks/benchmark_ppo.py

@ -66,7 +66,7 @@ def benchmark_train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

2
applications/ColossalChat/examples/training_scripts/train_dpo.py

@ -37,7 +37,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================

2
applications/ColossalChat/examples/training_scripts/train_ppo.py

@ -39,7 +39,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

2
applications/ColossalChat/examples/training_scripts/train_rm.py

@ -34,7 +34,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================

2
applications/ColossalChat/examples/training_scripts/train_sft.py

@ -29,7 +29,7 @@ def train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================

2
applications/ColossalEval/examples/dataset_evaluation/inference.py

@ -81,7 +81,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
accelerator = get_accelerator()
world_size = dist.get_world_size()

2
applications/ColossalEval/examples/gpt_evaluation/inference.py

@ -81,7 +81,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
world_size = dist.get_world_size()
rank = dist.get_rank()

8
applications/ColossalMoE/infer.py

@ -57,7 +57,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
config = MixtralConfig.from_pretrained(args.model_name)
@ -96,7 +96,11 @@ def main():
if coordinator.rank == 0:
text = ["Hello my name is"]
else:
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
text = [
"What's the largest country in the world?",
"How many people live in China?",
"帮我续写这首诗:离离原上草",
]
tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())

2
applications/ColossalMoE/tests/test_mixtral_layer.py

@ -50,7 +50,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()

2
applications/ColossalMoE/tests/test_moe_checkpoint.py

@ -133,7 +133,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch({}, rank, world_size, "localhost", port)
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()

8
applications/ColossalMoE/train.py

@ -145,7 +145,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
@ -195,9 +195,9 @@ def main():
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
warmup_steps=(
args.warmup_steps if args.warmup_steps is not None else int(args.num_epochs * len(dataloader) * 0.025)
),
eta_min=0.1 * args.lr,
)

2
colossalai/auto_parallel/offload/amp_optimizer.py

@ -126,7 +126,7 @@ class AMPOptimizer(OptimizerWrapper):
return self.grad_scaler.scale.item()
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = torch.cuda.IntTensor([0])
self.module.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):

4
colossalai/auto_parallel/offload/base_offload_module.py

@ -4,7 +4,7 @@ from typing import Optional, Set
import torch
import torch.nn as nn
from colossalai.utils import _cast_float
from colossalai.utils import _cast_float, get_current_device
from colossalai.utils.common import free_storage
from .region_manager import RegionManager
@ -25,7 +25,7 @@ class BaseOffloadModule:
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
self.overflow_counter = torch.cuda.IntTensor([0])
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_current_device())
self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream

5
colossalai/booster/plugin/torch_ddp_plugin.py

@ -10,6 +10,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from .dp_plugin_base import DPPluginBase
@ -203,7 +204,7 @@ class TorchDDPPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ["cuda"]
return ["cuda", "npu"]
def configure(
self,
@ -214,7 +215,7 @@ class TorchDDPPlugin(DPPluginBase):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.cuda()
model = model.to(get_current_device())
# convert model to sync bn
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

2
colossalai/inference/README.md

@ -114,7 +114,7 @@ import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
#launch distributed environment
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# load original model and tokenizer
model = LlamaForCausalLM.from_pretrained("/path/to/model")

16
colossalai/initialize.py

@ -2,20 +2,15 @@
# -*- encoding: utf-8 -*-
import os
import warnings
from pathlib import Path
from typing import Dict, Union
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import set_seed
def launch(
config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
@ -44,8 +39,6 @@ def launch(
Raises:
Exception: Raise exception when config type is wrong
"""
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
cur_accelerator = get_accelerator()
@ -68,7 +61,6 @@ def launch(
def launch_from_slurm(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
@ -95,7 +87,6 @@ def launch_from_slurm(
)
launch(
config=config,
rank=rank,
world_size=world_size,
host=host,
@ -107,7 +98,6 @@ def launch_from_slurm(
def launch_from_openmpi(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
@ -135,7 +125,6 @@ def launch_from_openmpi(
)
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
@ -147,9 +136,7 @@ def launch_from_openmpi(
)
def launch_from_torch(
config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
):
def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@ -171,7 +158,6 @@ def launch_from_torch(
)
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,

2
colossalai/legacy/inference/dynamic_batching/ray_dist_init.py

@ -56,7 +56,7 @@ class Worker:
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")

2
colossalai/legacy/inference/hybridengine/engine.py

@ -42,7 +42,7 @@ class CaiInferEngine:
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")

34
colossalai/legacy/inference/pipeline/README.md

@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
### Llama Throughput (tokens/s) | input length=1024, output length=128
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) |
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
#### A800 13b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|
| Pipeline Inference | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |

2
colossalai/legacy/inference/pipeline/benchmark/benchmark.py

@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
def data_gen(batch_size: int = 4, seq_len: int = 512):

2
colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py

@ -56,7 +56,7 @@ class Worker:
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")

2
colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py

@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC):
self.model.cuda()
self.model.eval()
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
shard_config = ShardConfig(
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}

2
colossalai/legacy/pipeline/rpc/utils.py

@ -114,7 +114,7 @@ def run_worker(rank, args, master_func):
port = args.master_port
backend = "nccl" if device == "cuda" else "gloo"
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
launch(rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(
rank=rank,
world_size=world_size,

4
colossalai/nn/optimizer/fused_adam.py

@ -8,7 +8,7 @@ Licensed under the MIT License.
"""
import torch
from colossalai.utils import multi_tensor_applier
from colossalai.utils import get_current_device, multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
@ -75,7 +75,7 @@ class FusedAdam(torch.optim.Optimizer):
fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
self.multi_tensor_adam = fused_optim.multi_tensor_adam
else:
raise RuntimeError("FusedAdam requires cuda extensions")

4
colossalai/nn/optimizer/hybrid_adam.py

@ -3,7 +3,7 @@ from typing import Any, Optional
import torch
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier
from colossalai.utils import get_current_device, multi_tensor_applier
from .cpu_adam import CPUAdam
@ -87,7 +87,7 @@ class HybridAdam(CPUAdam):
if torch.cuda.is_available():
fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
@torch.no_grad()
def step(self, closure=None, div_scale: float = -1):

2
colossalai/shardformer/README.md

@ -38,7 +38,7 @@ from transformers import BertForMaskedLM
import colossalai
# launch colossalai
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# create model
config = BertConfig.from_pretrained('bert-base-uncased')

2
colossalai/shardformer/examples/convergence_benchmark.py

@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any:
def train(args):
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# prepare for data and dataset

3
colossalai/shardformer/examples/performance_benchmark.py

@ -1,6 +1,7 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)

2
colossalai/shardformer/shard/shardformer.py

@ -26,7 +26,7 @@ class ShardFormer:
import colossalai
import torch
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig()

2
colossalai/tensor/d_tensor/README.md

@ -69,7 +69,7 @@ import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# define your device mesh
# assume you have 4 GPUs

2
docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md

@ -75,7 +75,7 @@ WARMUP_FRACTION = 0.1
we create a distributed environment.
```python
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch( seed=42)
coordinator = DistCoordinator()
```
prepare the dataset. You can use `plugin.prepare_dataloader` to generate a dataloader or customize your own dataloader.

2
docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md

@ -71,7 +71,7 @@ PP_SIZE = 2
Create a distributed environment.
```python
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=SEEDå)
colossalai.launch_from_torch( seed=SEEDå)
coordinator = DistCoordinator()
world_size = coordinator.world_size
```

2
docs/source/en/basics/booster_api.md

@ -55,7 +55,7 @@ from colossalai.booster.plugin import TorchDDPPlugin
def train():
# launch colossalai
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')
# create plugin and objects for training
plugin = TorchDDPPlugin()

18
docs/source/en/basics/launch_colossalai.md

@ -87,8 +87,7 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
colossalai.launch(config=args.config,
rank=args.rank,
colossalai.launch(rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
@ -106,20 +105,11 @@ First, we need to set the launch method in our code. As this is a wrapper of the
use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch
launcher and can be read from the environment variable directly.
config.py
```python
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2
```
train.py
```python
import colossalai
colossalai.launch_from_torch(
config="./config.py",
)
colossalai.launch_from_torch()
...
```
@ -203,7 +193,6 @@ Do this in your training script:
import colossalai
colossalai.launch_from_slurm(
config=<CONFIG>,
host=args.host,
port=args.port
)
@ -224,7 +213,6 @@ use them to start the distributed backend.
Do this in your train.py:
```python
colossalai.launch_from_openmpi(
config=<CONFIG>,
host=args.host,
port=args.port
)
@ -238,3 +226,5 @@ mpirun --hostfile <my_hostfile> -np <num_process> python train.py --host <node n
- --hostfile: use this option to specify a list of hosts on which to run
- --np: set the number of processes (GPUs) to launch in total. For example, if --np 4, 4 python processes will be initialized to run train.py.
<!-- doc-test-command: echo -->

2
docs/source/en/features/gradient_accumulation_with_booster.md

@ -45,7 +45,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
```
### Step 3. Create training components

2
docs/source/en/features/gradient_clipping_with_booster.md

@ -61,7 +61,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l
for other initialization methods.
```python
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
logger = get_dist_logger()
```

2
docs/source/en/features/lazy_init.md

@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
colossalai.launch({})
colossalai.launch()
plugin = GeminiPlugin()
booster = Booster(plugin)

10
docs/source/en/features/mixed_precision_training_with_booster.md

@ -20,10 +20,10 @@ In Colossal-AI, we have incorporated different implementations of mixed precisio
3. naive amp
| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent |
| -------------- | ----------------------- | ------------------------- | ---------------------------------------------------------------------------------------------------- |
| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation |
| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 |
| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 |
|----------------|-------------------------|---------------------------|------------------------------------------------------------------------------------------------------|
| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation |
| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 |
| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 |
The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex.
The last method is similar to Apex O2 level.
@ -164,7 +164,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
```

2
docs/source/en/features/nvme_offload.md

@ -185,7 +185,7 @@ Then we can train GPT model with Gemini. The placement policy of Gemini should b
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
config = GPT2Config()
with ColoInitContext(device=torch.cuda.current_device()):
model = GPT2LMHeadModel(config)

2
docs/source/en/features/zero_with_chunk.md

@ -174,7 +174,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# build criterion
criterion = GPTLMLoss()

2
docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md

@ -62,7 +62,7 @@ plugin = HybridParallelPlugin(
## 创建分布式环境.
```python
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
```
## 定义GPT-2模型的训练组件

2
docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md

@ -70,7 +70,7 @@ PP_SIZE = 2
首先我们创建一个分布式环境
```python
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=SEEDå)
colossalai.launch_from_torch(seed=SEEDå)
coordinator = DistCoordinator()
world_size = coordinator.world_size
```

2
docs/source/zh-Hans/basics/booster_api.md

@ -60,7 +60,7 @@ from colossalai.booster.plugin import TorchDDPPlugin
def train():
# launch colossalai
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')
# create plugin and objects for training
plugin = TorchDDPPlugin()

18
docs/source/zh-Hans/basics/launch_colossalai.md

@ -74,8 +74,7 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
colossalai.launch(config=args.config,
rank=args.rank,
colossalai.launch(rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
@ -93,20 +92,11 @@ PyTorch自带的启动器需要在每个节点上都启动命令才能启动多
首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。
分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。
config.py
```python
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2
```
train.py
```python
import colossalai
colossalai.launch_from_torch(
config="./config.py",
)
colossalai.launch_from_torch()
...
```
@ -186,7 +176,6 @@ colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --e
import colossalai
colossalai.launch_from_slurm(
config=<CONFIG>,
host=args.host,
port=args.port
)
@ -206,7 +195,6 @@ srun python train.py --host <master_node> --port 29500
您可以在您的训练脚本中尝试以下操作。
```python
colossalai.launch_from_openmpi(
config=<CONFIG>,
host=args.host,
port=args.port
)
@ -219,3 +207,5 @@ mpirun --hostfile <my_hostfile> -np <num_process> python train.py --host <node n
- --hostfile: 指定一个要运行的主机列表。
- --np: 设置总共要启动的进程(GPU)的数量。例如,如果 --np 4,4个 python 进程将被初始化以运行 train.py。
<!-- doc-test-command: echo -->

2
docs/source/zh-Hans/features/gradient_accumulation_with_booster.md

@ -46,7 +46,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
```

2
docs/source/zh-Hans/features/gradient_clipping_with_booster.md

@ -61,7 +61,7 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
我们需要初始化分布式环境. 为了快速演示,我们使用`launch_from_torch`. 您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md)
```python
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
logger = get_dist_logger()
```

2
docs/source/zh-Hans/features/lazy_init.md

@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
colossalai.launch({})
colossalai.launch()
plugin = GeminiPlugin()
booster = Booster(plugin)

12
docs/source/zh-Hans/features/mixed_precision_training_with_booster.md

@ -19,11 +19,11 @@ AMP 代表自动混合精度训练。
2. apex.amp
3. naive amp
| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 |
| -------------- | ------------ | ------------ | --------------------------------------------------------- |
| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 |
| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 |
| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 |
| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 |
|----------------|--------------|--------------|-------------------------------------------------------|
| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 |
| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 |
| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 |
前两个依赖于 PyTorch (1.6 及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现 inf 或 nan。我们修改了 torch amp 实现,使其现在与张量并行兼容。
@ -153,7 +153,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=dict())
colossalai.launch_from_torch()
```

2
docs/source/zh-Hans/features/nvme_offload.md

@ -175,7 +175,7 @@ Mem usage: 4968.016 MB
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
config = GPT2Config()
with ColoInitContext(device=torch.cuda.current_device()):
model = GPT2LMHeadModel(config)

2
docs/source/zh-Hans/features/zero_with_chunk.md

@ -174,7 +174,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
# build criterion
criterion = GPTLMLoss()

4
examples/community/roberta/pretraining/run_pretraining.py

@ -35,12 +35,12 @@ def main():
if args.vscode_debug:
colossalai.launch(
config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
)
args.local_rank = -1
args.log_interval = 1
else:
colossalai.launch_from_torch(config={}) # args.colossal_config
colossalai.launch_from_torch() # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(
f"launch_from_torch, world size: {torch.distributed.get_world_size()} | "

2
examples/images/dreambooth/debug.py

@ -9,7 +9,7 @@ from colossalai.zero import ColoInitContext
path = "/data/scratch/diffuser/stable-diffusion-v1-4"
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
with ColoInitContext(device="cpu"):
vae = AutoencoderKL.from_pretrained(
path,

4
examples/images/dreambooth/train_dreambooth_colossalai.py

@ -372,9 +372,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
else:
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
local_rank = dist.get_rank()
world_size = dist.get_world_size()

4
examples/images/dreambooth/train_dreambooth_colossalai_lora.py

@ -371,9 +371,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
else:
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
local_rank = gpc.get_local_rank(ParallelMode.DATA)
world_size = gpc.get_world_size(ParallelMode.DATA)

2
examples/images/resnet/train.py

@ -128,7 +128,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling

2
examples/images/vit/vit_benchmark.py

@ -46,7 +46,7 @@ def main():
args = parse_benchmark_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size

2
examples/images/vit/vit_train_demo.py

@ -137,7 +137,7 @@ def main():
args = parse_demo_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size

2
examples/inference/benchmark_llama.py

@ -136,7 +136,7 @@ def benchmark_inference(args):
def hybrid_inference(rank, world_size, port, args):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
benchmark_inference(args)

2
examples/inference/run_llama_inference.py

@ -68,7 +68,7 @@ def run_inference(args):
def run_tp_pipeline_inference(rank, world_size, port, args):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_inference(args)

2
examples/language/bert/benchmark.py

@ -81,7 +81,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size

2
examples/language/bert/finetune.py

@ -202,7 +202,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
lr = LEARNING_RATE * coordinator.world_size

3
examples/language/gpt/experiments/auto_offload/train_gpt_offload.py

@ -94,8 +94,7 @@ def train_gpt(args):
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
train_gpt(args)

2
examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py

@ -47,7 +47,7 @@ def get_data(batch_size, seq_len, vocab_size):
def main():
disable_existing_loggers()
launch_from_torch(config={})
launch_from_torch()
logger = get_dist_logger()
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
if FP16:

2
examples/language/gpt/gemini/train_gpt_demo.py

@ -132,7 +132,7 @@ def main():
PROF_FLAG = False # The flag of profiling, False by default
disable_existing_loggers()
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
logger = get_dist_logger()
logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

2
examples/language/gpt/hybridparallelism/benchmark.py

@ -67,7 +67,7 @@ def main():
parser.add_argument("--cpu_offload", action="store_true", help="Use gradient checkpointing")
args = parser.parse_args()
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
def empty_init():

2
examples/language/gpt/hybridparallelism/finetune.py

@ -196,7 +196,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size

4
examples/language/gpt/titans/train_gpt.py

@ -36,9 +36,9 @@ def main():
args = parser.parse_args()
disable_existing_loggers()
if args.from_torch:
colossalai.launch_from_torch(config=args.config)
colossalai.launch_from_torch()
else:
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
colossalai.launch_from_slurm(host=args.host, port=29500, seed=42)
logger = get_dist_logger()
data_path = None if args.use_dummy_dataset else os.environ["DATA"]

2
examples/language/grok-1/inference_tp.py

@ -16,7 +16,7 @@ if __name__ == "__main__":
parser = get_default_parser()
args = parser.parse_args()
start = time.time()
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
plugin = HybridParallelPlugin(
tp_size=coordinator.world_size,

2
examples/language/llama/benchmark.py

@ -78,7 +78,7 @@ def main():
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
args = parser.parse_args()
colossalai.launch_from_torch({})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
def empty_init():

2
examples/language/openmoe/benchmark/benchmark_cai.py

@ -146,7 +146,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
# Set plugin

2
examples/language/openmoe/train.py

@ -207,7 +207,7 @@ def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"

2
examples/language/opt/opt_benchmark.py

@ -46,7 +46,7 @@ def main():
args = parse_benchmark_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size

2
examples/language/opt/opt_train_demo.py

@ -64,7 +64,7 @@ def main():
args = parse_demo_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size

2
examples/language/palm/train.py

@ -102,7 +102,7 @@ args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
logger = get_dist_logger()

2
examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py

@ -20,7 +20,7 @@ def _benchmark(rank, world_size, port):
only result in minor performance drop. So at last we might be able to find better training batch size for our
model (combine with large batch training optimizer such as LAMB).
"""
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = tm.resnet152()
gm = symbolic_trace(model)
raw_graph = deepcopy(gm.graph)

2
examples/tutorial/auto_parallel/auto_ckpt_solver_test.py

@ -17,7 +17,7 @@ def _benchmark(rank, world_size, port, args):
The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
"""
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
if args.model == "resnet50":
model = tm.resnet50()
data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))

2
examples/tutorial/new_api/cifar_resnet/train.py

@ -128,7 +128,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling

2
examples/tutorial/new_api/cifar_vit/train.py

@ -148,7 +148,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={})
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling

2
examples/tutorial/new_api/glue_bert/finetune.py

@ -125,7 +125,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
colossalai.launch_from_torch(config={}, seed=42)
colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size

2
examples/tutorial/opt/opt/run_clm.py

@ -289,7 +289,7 @@ class DummyDataloader:
def main():
args = parse_args()
disable_existing_loggers()
colossalai.legacy.launch_from_torch(config=dict())
colossalai.legacy.launch_from_torch()
logger = get_dist_logger()
is_main_process = dist.get_rank() == 0

2
tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py

@ -27,7 +27,7 @@ except:
def _run_C_solver_consistency_test(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()

4
tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py

@ -75,7 +75,7 @@ def check_backward_consistency(
def _run_ckpt_solver(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@ -111,7 +111,7 @@ def test_ckpt_solver():
def _run_ckpt_solver_torch11(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True

3
tests/test_auto_parallel/test_offload/test_perf.py

@ -141,8 +141,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_fwd_bwd()

4
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py

@ -42,7 +42,7 @@ class ConvModel(torch.nn.Module):
def check_linear_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModel(4, 8).cuda()
input = torch.rand(4, 4).cuda()
output_compare = model(input)
@ -59,7 +59,7 @@ def check_linear_module(rank, world_size, port):
def check_conv_module(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = ConvModel(3, 6, 2).cuda()
input = torch.rand(4, 3, 64, 64).cuda()
output_compare = model(input)

2
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py

@ -39,7 +39,7 @@ class GPT2MLPWithCkpt(nn.Module):
def check_act_ckpt(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
torch.rand(1, 64, HIDDEN_SIZE)
input_sample = {

2
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py

@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
def check_compatibility_with_ddp(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MLP(4).cuda()
if rank in [0, 1]:
input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()

2
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

@ -34,7 +34,7 @@ class MLP(torch.nn.Module):
def check_auto_parallel_with_gemini(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MLP(4).half().cuda()
if rank in [0, 1]:
input = torch.arange(0, 16).reshape(4, 4).half().cuda()

2
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py

@ -73,7 +73,7 @@ def _check_module_grad(
def check_attention_layer(rank, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

2
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py

@ -31,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda()
input = torch.rand(32, 1024).cuda()
input.requires_grad = True

4
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py

@ -31,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda()
input = torch.rand(4, 4, 64, 64).cuda()
input.requires_grad = True
@ -72,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = ConvFunctionModule().cuda()
input = torch.rand(4, 4, 64, 64).cuda()
input.requires_grad = True

4
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py

@ -30,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True
@ -68,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MyModule().cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True

2
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py

@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.BatchNorm2d(128)).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True

4
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py

@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True
@ -62,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py

@ -40,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = module(using_kwargs).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@ -150,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg
def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py

@ -40,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
if model_cls == AddmmModel:
model = AddmmModel().cuda()
else:

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py

@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_bn_module_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.BatchNorm2d(16)).cuda()
physical_mesh_id = torch.arange(0, 4)

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py

@ -34,7 +34,7 @@ class LinearModule(torch.nn.Module):
def check_linear_module_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda()
physical_mesh_id = torch.arange(0, 4)

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py

@ -30,7 +30,7 @@ class LinearModule(torch.nn.Module):
def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModule(16, 32, bias=bias).cuda()
physical_mesh_id = torch.arange(0, 4)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save