mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] make gemini usage simple (#1821)
parent
99870726b1
commit
cd5a0d56fa
|
@ -1,3 +1,4 @@
|
|||
from .data_parallel import ColoDDP, ZeroDDP
|
||||
from .gemini_parallel import GeminiDDP
|
||||
|
||||
__all__ = ['ColoDDP', 'ZeroDDP']
|
||||
__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
|
||||
|
|
|
@ -188,25 +188,16 @@ class ColoDDP(torch.nn.Module):
|
|||
|
||||
|
||||
class ZeroDDP(ColoDDP):
|
||||
"""ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
|
||||
We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
|
||||
"""ZeRO DDP for ColoTensor.
|
||||
Warning: Nested ZeroDDP is not supported now.
|
||||
It is designed to be used with ChunkManager and GeminiManager.
|
||||
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||
|
||||
Example:
|
||||
>>> model = torch.nn.Linear(20, 1)
|
||||
>>> placement_policy = 'cuda'
|
||||
>>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
|
||||
>>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
|
||||
>>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
>>> model = ZeroDDP(model, gemini_manager)
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Module to apply ZeRO-DP.
|
||||
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
|
||||
For more details, see the API reference of ``GeminiManager``.
|
||||
pin_memory (bool): Chunks on CPU Memory use pin-memory.
|
||||
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
|
||||
"""
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
|
||||
from colossalai.gemini.chunk import init_chunk_manager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
from .data_parallel import ZeroDDP
|
||||
|
||||
|
||||
class GeminiDDP(ZeroDDP):
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
device: torch.device,
|
||||
placement_policy: str = "cpu",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
search_range_mb: int = 32) -> None:
|
||||
"""
|
||||
A torch.Module warpper using ZeRODPP and Genimi.
|
||||
ZeRO is for parallel. Gemini is for memory management.
|
||||
|
||||
Example:
|
||||
model is initialized under the context of ColoInitContext
|
||||
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
|
||||
>>> logits = model(x)
|
||||
>>> loss = criterion(logits, labels)
|
||||
>>> model.backward(loss)
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): the model to be wrapped.
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
|
||||
"""
|
||||
chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
|
|
@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
|
|||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
|
@ -43,7 +42,6 @@ import colossalai
|
|||
import transformers
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
|
@ -380,11 +378,8 @@ def main():
|
|||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
from colossalai.gemini import GeminiManager
|
||||
from colossalai.gemini.chunk import init_chunk_manager
|
||||
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
pg = ProcessGroup()
|
||||
|
@ -393,6 +388,8 @@ def main():
|
|||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
|
||||
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
||||
|
||||
|
|
Loading…
Reference in New Issue