[chatgpt] disable shard init for colossalai (#2767)

pull/2775/head
ver217 2023-02-16 20:09:34 +08:00 committed by GitHub
parent d6d6dec190
commit a88bc828d5
1 changed file with 7 additions and 1 deletion

@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional
 
 import torch
@@ -23,6 +24,7 @@ class ColossalAIStrategy(DDPStrategy):
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3).
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
+            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
         placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda').
             If it is cpu, parameters, gradients and optimizer states will be offloaded to CPU;
             if it is cuda, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
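
For context, a minimal sketch of the incompatibility the new docstring line describes (the `transformers` import and 'gpt2' checkpoint below are illustrative assumptions, not part of this commit): `from_pretrained()` loads every parameter as a full tensor, while `shard_init=True` would have each rank allocate only a shard of each parameter at construction time, so the loaded state dict could not be assigned one-to-one.

```python
# Illustrative sketch; the transformers import and 'gpt2' are assumptions.
from transformers import AutoModelForCausalLM

# from_pretrained() materializes full, unsharded parameter tensors.
model = AutoModelForCausalLM.from_pretrained('gpt2')

# Under ZeRO-3 with shard_init=True, each rank would hold only a shard of
# every parameter from construction onward, so these full tensors could not
# be loaded one-to-one; hence the fallback to shard_init=False below.
print(sum(p.numel() for p in model.parameters()), 'unsharded parameters loaded')
```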
@@ -50,7 +52,7 @@ class ColossalAIStrategy(DDPStrategy):
         self,
         stage: int = 3,
         seed: int = 42,
-        shard_init: bool = True,  # only for stage 3
+        shard_init: bool = False,  # only for stage 3
         placement_policy: str = 'cuda',
         pin_memory: bool = True,  # only for stage 3
         force_outputs_fp32: bool = False,  # only for stage 3
@@ -72,6 +74,10 @@ class ColossalAIStrategy(DDPStrategy):
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
         self.stage = stage
+        # TODO(ver217): support shard_init when using from_pretrained()
+        if shard_init:
+            warnings.warn('Shard init is not supported yet. Ignoring it and falling back to shard_init=False.')
+            shard_init = False
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
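
A minimal usage sketch of the changed behavior (the import path follows the ChatGPT application layout at the time and is an assumption; constructing the strategy requires a launched distributed environment):

```python
import warnings

# Assumed import path; adjust to your checkout. Constructing the strategy
# requires a launched distributed environment (e.g., via torchrun).
from chatgpt.trainer.strategies import ColossalAIStrategy

# The default is now shard_init=False, keeping from_pretrained() usable.
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
assert strategy.shard_init is False

# Passing shard_init=True is ignored with a warning and falls back to False.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    strategy = ColossalAIStrategy(stage=3, shard_init=True)
assert strategy.shard_init is False
assert any('Shard init is not supported' in str(w.message) for w in caught)
```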