mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] make gemini usage simple (#1821)
parent 99870726b1
commit cd5a0d56fa
@@ -1,3 +1,4 @@
 from .data_parallel import ColoDDP, ZeroDDP
+from .gemini_parallel import GeminiDDP
 
-__all__ = ['ColoDDP', 'ZeroDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
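After this hunk, GeminiDDP sits next to ColoDDP and ZeroDDP in the package's public exports, so downstream code can pull all three wrappers from one place. A minimal sketch of the resulting import surface (the package path is taken from the example-script hunks further below):

# Sketch: all three wrappers are now importable from a single module.
from colossalai.nn.parallel import ColoDDP, GeminiDDP, ZeroDDP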
@@ -188,25 +188,16 @@ class ColoDDP(torch.nn.Module):
 
 
 class ZeroDDP(ColoDDP):
-    """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
-    We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
+    """ZeRO DDP for ColoTensor.
+    Warning: Nested ZeroDDP is not supported now.
+
+    It is designed to be used with ChunkManager and GeminiManager.
     For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
 
-    Example:
-        >>> model = torch.nn.Linear(20, 1)
-        >>> placement_policy = 'cuda'
-        >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
-        >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
-        >>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
-        >>> model = ZeroDDP(model, gemini_manager)
-        >>> logits = model(x)
-        >>> loss = criterion(logits, labels)
-        >>> model.backward(loss)
-
     Args:
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
+        pin_memory (bool): Chunks on CPU memory use pinned memory.
         force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """
 
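The docstring example removed above documented exactly the manual wiring that the new wrapper automates. For reference, a hedged sketch of the equivalent setup against the post-commit API, assembled from the removed example and the new-file hunk below (the toy model and the search_range_mb value are illustrative, and a launched distributed context is assumed):

import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP

model = torch.nn.Linear(20, 1)
placement_policy = 'cuda'

# Search for a chunk size and build the chunk manager on the target device.
chunk_manager = init_chunk_manager(model=model,
                                   init_device=torch.device('cuda'),
                                   search_range_mb=32)
# GeminiManager coordinates chunk placement across CPU/GPU memory; the new
# gemini_parallel.py passes the module as a third argument, mirrored here.
gemini_manager = GeminiManager(placement_policy, chunk_manager, model)
model = ZeroDDP(model, gemini_manager, pin_memory=True)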
@@ -0,0 +1,39 @@
+import torch
+
+from colossalai.gemini.chunk import init_chunk_manager
+from colossalai.gemini.gemini_mgr import GeminiManager
+
+from .data_parallel import ZeroDDP
+
+
+class GeminiDDP(ZeroDDP):
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 device: torch.device,
+                 placement_policy: str = "cpu",
+                 pin_memory: bool = False,
+                 force_outputs_fp32: bool = False,
+                 search_range_mb: int = 32) -> None:
+        """
+        A torch.nn.Module wrapper using ZeroDDP and Gemini.
+        ZeRO is for parallelism. Gemini is for memory management.
+
+        Example:
+            The model must be initialized under the context of ColoInitContext.
+            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
+            >>> logits = model(x)
+            >>> loss = criterion(logits, labels)
+            >>> model.backward(loss)
+
+        Args:
+            module (torch.nn.Module): the model to be wrapped.
+            device (torch.device): device to place the model.
+            placement_policy (str, optional): "cpu", "cuda" or "auto". Defaults to "cpu".
+            pin_memory (bool, optional): use pinned memory on the CPU. Defaults to False.
+            force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
+            search_range_mb (int, optional): chunk size search range, in megabytes. Defaults to 32.
+        """
+        chunk_manager = init_chunk_manager(model=module, init_device=device, search_range_mb=search_range_mb)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
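With the new class in place, wrapping a model takes one constructor call. A minimal end-to-end sketch under stated assumptions: the import paths for get_current_device and ColoInitContext are assumed from contemporaneous Colossal-AI examples, and the toy model, input, and loss are placeholders rather than part of this commit:

import torch

import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device  # assumed import path
from colossalai.utils.model.colo_init_context import ColoInitContext  # assumed import path

colossalai.launch_from_torch(config={})

# Build the model under ColoInitContext, as the GeminiDDP docstring requires.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(20, 1)

# One call replaces the chunk-manager / gemini-manager boilerplate.
model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)

x = torch.rand(4, 20, device=get_current_device())
logits = model(x)
loss = logits.float().mean()  # stand-in for a real criterion
model.backward(loss)          # ZeroDDP provides its own backward()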
@@ -24,7 +24,6 @@ https://huggingface.co/models?filter=text-generation
 
 import math
 import os
-import random
 import time
 from itertools import chain
 
@@ -43,7 +42,6 @@ import colossalai
 import transformers
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP
@@ -380,11 +378,8 @@ def main():
     cai_version = colossalai.__version__
     logger.info(f'using Colossal-AI version {cai_version}')
     if version.parse(cai_version) > version.parse("0.1.10"):
-        from colossalai.gemini import GeminiManager
-        from colossalai.gemini.chunk import init_chunk_manager
-        chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=32)
-        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
-        model = ZeroDDP(model, gemini_manager, pin_memory=True)
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         pg = ProcessGroup()
@@ -393,6 +388,8 @@ def main():
                                      pg,
                                      enable_distributed_storage=True,
                                      init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
+        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
+        model = ZeroDDP(model, gemini_manager)
 
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
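The example script keeps both code paths alive behind a version gate. A condensed sketch of that dispatch idiom, factored into a helper; the function name and error handling are illustrative rather than part of the commit, and the legacy 0.1.9-0.1.10 branch body is elided here just as the hunks above elide parts of it:

from packaging import version

import colossalai
from colossalai.utils import get_current_device  # assumed import path

def wrap_for_gemini(model, placement_policy: str):
    """Wrap `model` with whichever Gemini API the installed version exposes."""
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        # New one-liner introduced by this commit.
        from colossalai.nn.parallel import GeminiDDP
        return GeminiDDP(model, device=get_current_device(),
                         placement_policy=placement_policy, pin_memory=True)
    if version.parse("0.1.9") <= version.parse(cai_version) <= version.parse("0.1.10"):
        # Legacy path: build ChunkManager/GeminiManager by hand and wrap
        # with ZeroDDP, as shown in the two hunks above.
        raise NotImplementedError("see the 0.1.9-0.1.10 branch in the diff above")
    raise RuntimeError(f"unsupported Colossal-AI version: {cai_version}")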