[Gemini] make gemini usage simple (#1821)

Jiarui Fang 2022-11-08 15:53:13 +08:00 committed by GitHub
parent 99870726b1
commit cd5a0d56fa
4 changed files with 49 additions and 21 deletions

colossalai/nn/parallel/__init__.py

@@ -1,3 +1,4 @@
 from .data_parallel import ColoDDP, ZeroDDP
+from .gemini_parallel import GeminiDDP
 
-__all__ = ['ColoDDP', 'ZeroDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
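
With the new export in place, user code can pull the wrapper straight from the package root; a trivial check, assuming nothing beyond what this diff adds:

from colossalai.nn.parallel import ColoDDP, ZeroDDP, GeminiDDP  # GeminiDDP is newly exported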

colossalai/nn/parallel/data_parallel.py

@@ -188,25 +188,16 @@ class ColoDDP(torch.nn.Module):
 class ZeroDDP(ColoDDP):
-    """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
-    We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
+    """ZeRO DDP for ColoTensor.
+    Warning: Nested ZeroDDP is not supported now.
+    It is designed to be used with ChunkManager and GeminiManager.
+    For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
-    Example:
-        >>> model = torch.nn.Linear(20, 1)
-        >>> placement_policy = 'cuda'
-        >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
-        >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
-        >>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        >>> model = ZeroDDP(model, gemini_manager)
-        >>> logits = model(x)
-        >>> loss = criterion(logits, labels)
-        >>> model.backward(loss)
 
     Args:
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and the heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
         pin_memory (bool): If True, chunks placed in CPU memory use pinned memory.
         force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """

colossalai/nn/parallel/gemini_parallel.py

@@ -0,0 +1,39 @@
+import torch
+
+from colossalai.gemini.chunk import init_chunk_manager
+from colossalai.gemini.gemini_mgr import GeminiManager
+
+from .data_parallel import ZeroDDP
+
+
+class GeminiDDP(ZeroDDP):
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 device: torch.device,
+                 placement_policy: str = "cpu",
+                 pin_memory: bool = False,
+                 force_outputs_fp32: bool = False,
+                 search_range_mb: int = 32) -> None:
+        """
+        A torch.nn.Module wrapper that combines ZeroDDP and Gemini:
+        ZeRO provides data parallelism, Gemini manages the heterogeneous memory space.
+
+        Example:
+            model is initialized under the context of ColoInitContext
+            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
+            >>> logits = model(x)
+            >>> loss = criterion(logits, labels)
+            >>> model.backward(loss)
+
+        Args:
+            module (torch.nn.Module): the model to be wrapped.
+            device (torch.device): the device to place the model on.
+            placement_policy (str, optional): "cpu", "cuda" or "auto". Defaults to "cpu".
+            pin_memory (bool, optional): use pinned memory for chunks kept in CPU memory. Defaults to False.
+            force_outputs_fp32 (bool, optional): if True, cast outputs to fp32. Defaults to False.
+            search_range_mb (int, optional): search range for the chunk size, in MB. Defaults to 32.
+        """
+        chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
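
A fuller end-to-end sketch of the new entry point. The ColoInitContext import path and the training placeholders (MyModel, x, labels, criterion) are assumptions for illustration, not part of this commit:

import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# build the model under ColoInitContext so its parameters are ColoTensors
with ColoInitContext(device=get_current_device()):
    model = MyModel()

model = GeminiDDP(model, device=get_current_device(), placement_policy="auto", pin_memory=True)

logits = model(x)                 # forward as usual
loss = criterion(logits, labels)
model.backward(loss)              # note: model.backward(loss), not loss.backward()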

run_clm.py

@@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
 import math
 import os
 import random
-import time
 from itertools import chain
@@ -43,7 +42,6 @@ import colossalai
 import transformers
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
@@ -380,11 +378,8 @@ def main():
     cai_version = colossalai.__version__
     logger.info(f'using Colossal-AI version {cai_version}')
     if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.gemini import GeminiManager
-        from colossalai.gemini.chunk import init_chunk_manager
-        chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
-        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
-        model = ZeroDDP(model, gemini_manager, pin_memory=True)
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         pg = ProcessGroup()
@@ -393,6 +388,8 @@ def main():
                                      pg,
                                      enable_distributed_storage=True,
                                      init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
+        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
+        model = ZeroDDP(model, gemini_manager)
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
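
The optimizer side of the example is untouched by this commit. For completeness, a sketch of how the HybridAdam import above is typically paired with the wrapped model in this era of the API; ZeroOptimizer and the initial_scale value are taken from contemporaneous examples, not from this diff:

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

optimizer = HybridAdam(model.parameters(), lr=1e-3)
# ZeroOptimizer ties the optimizer to the chunked fp16 parameters and handles loss scaling
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)

for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    model.backward(outputs['loss'])  # backward goes through the ZeroDDP wrapper
    optimizer.step()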