diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
index 1a7ab9a65..22d52fb3c 100644
--- a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -175,11 +175,11 @@ In this way, users can train their models as usual.
 In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency.For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process:
 
 ```python
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
     model = GeminiDDP(model,
                       device=get_current_device(),
-                      placement_policy=placememt_policy,
+                      placement_policy=placement_policy,
                       pin_memory=True,
                       search_range_mb=32)
     return model
diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md
index 8448c52ac..d7a99f2fb 100644
--- a/docs/source/en/features/zero_with_chunk.md
+++ b/docs/source/en/features/zero_with_chunk.md
@@ -185,23 +185,23 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
 Define a model which uses Gemini + ZeRO DDP:
 
 ```python
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
     cai_version = colossalai.__version__
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
         model = GeminiDDP(model,
                           device=get_current_device(),
-                          placement_policy=placememt_policy,
+                          placement_policy=placement_policy,
                           pin_memory=True,
                           search_range_mb=32)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         chunk_manager = ChunkManager(chunk_size,
                                      pg,
                                      enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+                                     init_device=GeminiManager.get_default_device(placement_policy))
         model = ZeroDDP(model, gemini_manager)
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
index f3c6247c3..c4131e593 100644
--- a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -159,11 +159,11 @@ for mn, module in model.named_modules():
 在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销，提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md)，你可以将这两部分内容结合起来看从而理解我们整个训练流程：
 
 ```python
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
     model = GeminiDDP(model,
                       device=get_current_device(),
-                      placement_policy=placememt_policy,
+                      placement_policy=placement_policy,
                       pin_memory=True,
                       search_range_mb=32)
     return model
diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md
index 72403bf61..ba57ba4e8 100644
--- a/docs/source/zh-Hans/features/zero_with_chunk.md
+++ b/docs/source/zh-Hans/features/zero_with_chunk.md
@@ -185,23 +185,23 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
 定义一个使用 Gemini + ZeRO DDP 的模型：
 
 ```python
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
     cai_version = colossalai.__version__
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
         model = GeminiDDP(model,
                           device=get_current_device(),
-                          placement_policy=placememt_policy,
+                          placement_policy=placement_policy,
                           pin_memory=True,
                           search_range_mb=32)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         chunk_manager = ChunkManager(chunk_size,
                                      pg,
                                      enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+                                     init_device=GeminiManager.get_default_device(placement_policy))
         model = ZeroDDP(model, gemini_manager)
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index e6159e105..d07febea0 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -340,12 +340,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
     model = GeminiDDP(model,
                       device=get_current_device(),
-                      placement_policy=placememt_policy,
+                      placement_policy=placement_policy,
                       pin_memory=True,
                       search_range_mb=64)
     return model
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index 1b2fc778d..6715b473a 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -342,12 +342,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
     model = GeminiDDP(model,
                       device=get_current_device(),
-                      placement_policy=placememt_policy,
+                      placement_policy=placement_policy,
                       pin_memory=True,
                       search_range_mb=64)
     return model
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 7923e4fc8..b16da1c77 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -102,23 +102,23 @@ def get_model_size(model: nn.Module):
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
     cai_version = colossalai.__version__
     if version.parse(cai_version) > version.parse("0.1.10"):
         from colossalai.nn.parallel import GeminiDDP
         model = GeminiDDP(model,
                           device=get_current_device(),
-                          placement_policy=placememt_policy,
+                          placement_policy=placement_policy,
                           pin_memory=True,
                           search_range_mb=32)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         chunk_manager = ChunkManager(chunk_size,
                                      pg,
                                      enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(placememt_policy))
+                                     init_device=GeminiManager.get_default_device(placement_policy))
         model = ZeroDDP(model, gemini_manager)
     else:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
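
The renamed argument is simply forwarded to `GeminiDDP`'s own `placement_policy` keyword. The sketch below is a minimal, hypothetical usage example rather than part of the patch: it assumes the ColossalAI 0.1.x-era APIs referenced in the diff (`colossalai.nn.parallel.GeminiDDP`, `colossalai.utils.get_current_device`) and uses a stand-in `TinyNet` model to show a caller passing the corrected keyword.

```python
# Hypothetical usage sketch only -- not part of the patch above.
# Assumes the ColossalAI 0.1.x-era APIs referenced in the diff.
import colossalai
import torch
import torch.nn as nn

from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device


class TinyNet(nn.Module):
    """Stand-in model; the real examples wrap GPT/PaLM/Stable Diffusion modules."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
    # Mirrors the patched helper: the renamed `placement_policy` argument is
    # forwarded straight to GeminiDDP's keyword of the same name.
    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=placement_policy,
                      pin_memory=True,
                      search_range_mb=32)
    return model


if __name__ == "__main__":
    # GeminiDDP expects the distributed environment to be initialized first,
    # e.g. via colossalai.launch_from_torch under torchrun.
    colossalai.launch_from_torch(config={})
    model = gemini_zero_dpp(TinyNet(), placement_policy="auto")
```

Run the sketch under `torchrun` so the process group exists before `GeminiDDP` shards the parameters according to the chosen placement policy.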