From 78fd31f9c15b698a4ed07748096684fa40bbc11a Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 24 Mar 2023 12:15:06 +0800
Subject: [PATCH] [chatgpt] add precision option for colossalai (#3233)

---
 .../chatgpt/trainer/strategies/colossalai.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
index f11dc6f75..0a7c91732 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):
 
     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(
             self,
             stage: int = 3,
+            precision: str = 'fp16',
             seed: int = 42,
             shard_init: bool = False, # only for stage 3
             placement_policy: str = 'cuda',
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
             norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()
 
     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model
 
     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
-                module.merge_weights=True
+                module.merge_weights = True
                 module.eval()
         # get state_dict and save