mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] add precision option for colossalai (#3233)
parent: bd39877da4
commit: 78fd31f9c1
@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):
     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(
             self,
             stage: int = 3,
+            precision: str = 'fp16',
             seed: int = 42,
             shard_init: bool = False,  # only for stage 3
             placement_policy: str = 'cuda',
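The new keyword lands between stage and seed in the constructor. A minimal usage sketch follows; the import path is an assumption (it is not shown in this diff), so treat it as illustrative only.

    # Illustrative usage; the import path below is assumed, not confirmed by this diff.
    from chatgpt.trainer.strategies import ColossalAIStrategy

    # ZeRO-2 can keep full-precision weights.
    strategy_fp32 = ColossalAIStrategy(stage=2, precision='fp32', seed=42, placement_policy='cuda')

    # ZeRO-3 defaults to fp16; passing precision='fp32' here would be downgraded with a warning.
    strategy_zero3 = ColossalAIStrategy(stage=3, placement_policy='cpu')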
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
             norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 'Shard init is not supported with model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
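The validation added in this hunk reduces to a small, self-contained rule. A sketch of the same behaviour as a plain function (an illustrative restatement, not the method itself):

    import warnings

    def resolve_precision(stage: int, precision: str) -> str:
        # Restates the checks above: only 'fp32'/'fp16' are accepted,
        # and ZeRO-3 falls back to fp16 after a warning.
        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
        if stage == 3 and precision == 'fp32':
            warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')
            precision = 'fp16'
        return precision

    assert resolve_precision(2, 'fp32') == 'fp32'
    assert resolve_precision(3, 'fp32') == 'fp16'  # warns, then falls back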
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()
 
     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model
 
     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
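For stages 1 and 2 the wrapped model is cast to half precision explicitly; stage 3 skips the cast, presumably because the Gemini wrapper already manages fp16 parameters itself. A standalone illustration of that branch in plain PyTorch (hypothetical helper name, without the ZeRO wrapper):

    import torch
    import torch.nn as nn

    def maybe_cast_to_half(model: nn.Module, stage: int, precision: str) -> nn.Module:
        # Mirrors the new branch in setup_model: only non-stage-3 runs need the explicit cast.
        if stage != 3 and precision == 'fp16':
            model = model.half()
        return model

    model = maybe_cast_to_half(nn.Linear(8, 8), stage=2, precision='fp16')
    assert next(model.parameters()).dtype == torch.float16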
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
-                module.merge_weights=True
+                module.merge_weights = True
                 module.eval()
         # get state_dict and save
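The last hunk only fixes the spacing around `=`, but the surrounding save path is worth a note: before exporting a state dict, every LoraLinear is asked to merge its low-rank update into the base weight and is switched to eval mode. A generic sketch of what such a merge usually computes (the common LoRA formulation, not this repository's LoraLinear implementation):

    import torch

    def merge_lora_weights(weight: torch.Tensor,
                           lora_A: torch.Tensor,
                           lora_B: torch.Tensor,
                           scaling: float) -> torch.Tensor:
        # Fold the low-rank update B @ A back into the frozen base weight so the
        # merged layer can be saved and loaded as an ordinary nn.Linear.
        return weight + scaling * (lora_B @ lora_A)

    # Shapes for a Linear(in_features=16, out_features=32) with LoRA rank r=4:
    W = torch.randn(32, 16)
    A = torch.randn(4, 16)   # lora_A: (r, in_features)
    B = torch.randn(32, 4)   # lora_B: (out_features, r)
    merged = merge_lora_weights(W, A, B, scaling=0.5)
    assert merged.shape == W.shape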