
[chatgpt] add precision option for colossalai (#3233)

pull/3221/head
ver217 authored 2 years ago, committed by GitHub
commit 78fd31f9c1
1 changed file:
    applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py  (14 changed lines)
@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):
     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(
             self,
             stage: int = 3,
+            precision: str = 'fp16',
             seed: int = 42,
             shard_init: bool = False,  # only for stage 3
             placement_policy: str = 'cuda',
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
                 norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()

     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model

     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
-                module.merge_weights=True
+                module.merge_weights = True
                 module.eval()
         # get state_dict and save
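
For reviewers trying the option out, here is a minimal usage sketch (not part of this commit). It assumes ColossalAIStrategy is importable from chatgpt.trainer.strategies, as the file path above suggests, and omits the surrounding model/optimizer/training setup.

# Illustrative sketch only; import path assumed from the file layout above.
from chatgpt.trainer.strategies import ColossalAIStrategy

# ZeRO-2 with fp16: setup_model() wraps the model and then casts it with .half(),
# since the explicit half-precision cast applies only to stages other than 3.
strategy = ColossalAIStrategy(stage=2, precision='fp16')

# ZeRO-2 with fp32: no cast, the model stays in full precision.
strategy_fp32 = ColossalAIStrategy(stage=2, precision='fp32')

# ZeRO-3 only supports fp16: passing precision='fp32' emits a warning
# in __init__ and falls back to fp16.
strategy_zero3 = ColossalAIStrategy(stage=3, precision='fp32')

Any other precision string fails the new assertion in __init__ before training starts.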
