mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] add precision option for colossalai (#3233)
parent bd39877da4
commit 78fd31f9c1
@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):

     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(
             self,
             stage: int = 3,
+            precision: str = 'fp16',
             seed: int = 42,
             shard_init: bool = False,  # only for stage 3
             placement_policy: str = 'cuda',
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
             norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 f'Shard init is not supported with model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
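A minimal usage sketch of the new option, assuming the ChatGPT application's import path and a distributed environment that has already been launched (e.g. via torchrun); these assumptions are not part of this diff:

    from chatgpt.trainer.strategies import ColossalAIStrategy  # import path assumed

    # ZeRO stage 1/2 keep the requested precision as-is.
    strategy = ColossalAIStrategy(stage=2, precision='fp32', seed=42)

    # ZeRO stage 3 only supports fp16: requesting 'fp32' emits a warning
    # and the strategy falls back to precision='fp16'.
    strategy = ColossalAIStrategy(stage=3, precision='fp32')
    assert strategy.precision == 'fp16'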
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()

     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model

     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
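A short follow-up sketch of how the precision option plays out in setup_model, under the same assumptions as above (nn.Linear stands in for a real model): for stages 1 and 2 the wrapped model is cast to half precision explicitly, while stage 3 leaves fp16 handling to the ZeRO/Gemini wrapper.

    import torch
    import torch.nn as nn

    strategy = ColossalAIStrategy(stage=2, precision='fp16')
    with strategy.model_init_context():
        model = nn.Linear(16, 16)  # stand-in for a real model

    model = strategy.setup_model(model)  # stage != 3, so model.half() is applied here
    assert next(model.parameters()).dtype == torch.float16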
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
-                module.merge_weights=True
+                module.merge_weights = True
                 module.eval()
         # get state_dict and save