@@ -30,6 +30,7 @@ class ColossalAIStrategy(DDPStrategy):
 
     Args:
         stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
         seed(int): The seed for the random number generator.
         shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
             This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -59,6 +60,7 @@ class ColossalAIStrategy(DDPStrategy):
     def __init__(
             self,
             stage: int = 3,
+            precision: str = 'fp16',
             seed: int = 42,
             shard_init: bool = False,  # only for stage 3
             placement_policy: str = 'cuda',
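
For reference, a minimal usage sketch of the new `precision` argument. This is not taken from the PR: the toy `nn.Linear` model and the hyperparameter values are placeholders, and the `HybridAdam` import path is assumed from ColossalAI's usual layout.

```python
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam  # assumed import path

# stages 1/2 accept 'fp32' or 'fp16'; stage 3 forces 'fp16' (see the check below)
strategy = ColossalAIStrategy(stage=2, precision='fp16', seed=42, placement_policy='cuda')

with strategy.model_init_context():
    model = nn.Linear(1024, 1024)          # placeholder for a real Coati model

optimizer = HybridAdam(model.parameters(), lr=1e-3)
model = strategy.setup_model(model)        # ZeRO wrapping (+ .half() for stage 1/2 fp16)
optimizer = strategy.setup_optimizer(optimizer, model)
```

In real training code these steps are normally driven through `strategy.prepare()`, which the warning text in the hunk below also points to.
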
@@ -81,12 +83,17 @@ class ColossalAIStrategy(DDPStrategy):
             norm_type: float = 2.0) -> None:
         super().__init__(seed)
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
         self.stage = stage
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
                 f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
             )
+        if stage == 3 and precision == 'fp32':
+            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+            precision = 'fp16'
+        self.precision = precision
         self.shard_init = shard_init
         self.gemini_config = dict(device=get_current_device(),
                                   placement_policy=placement_policy,
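
The net effect of the checks added above: only 'fp32' and 'fp16' are accepted, and requesting fp32 together with stage 3 falls back to fp16 with a warning. A standalone sketch of that rule; the `resolve_precision` helper is hypothetical and exists only to make the branch testable in isolation.

```python
import warnings


def resolve_precision(stage: int, precision: str) -> str:
    """Hypothetical helper mirroring the checks added to __init__ above."""
    assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
    if stage == 3 and precision == 'fp32':
        warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')
        return 'fp16'
    return precision


assert resolve_precision(2, 'fp32') == 'fp32'  # stages 1/2 keep fp32
assert resolve_precision(3, 'fp32') == 'fp16'  # stage 3 falls back to fp16 (with a warning)
```
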
@@ -127,7 +134,10 @@ class ColossalAIStrategy(DDPStrategy):
         return super().model_init_context()
 
     def setup_model(self, model: nn.Module) -> nn.Module:
-        return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+        if self.stage != 3 and self.precision == 'fp16':
+            model = model.half()
+        return model
 
     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
         assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
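
For stages 1 and 2 the ZeRO wrapper leaves the parameter dtype untouched, so the fp16 cast is now done explicitly with `model.half()`; stage 3 (Gemini) manages fp16 parameters internally, which is presumably why it is excluded from the cast. A self-contained illustration of what the added cast does, in plain PyTorch without ColossalAI:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)                    # toy stand-in, fp32 parameters by default
assert next(model.parameters()).dtype == torch.float32

model = model.half()                       # what the new stage-1/2 fp16 branch does
assert next(model.parameters()).dtype == torch.float16
```
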
@@ -159,7 +169,7 @@ class ColossalAIStrategy(DDPStrategy):
         # merge lora_weights into weights
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
-                module.merge_weights=True
+                module.merge_weights = True
                 module.eval()
         # get state_dict and save
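
The hunk itself is only a spacing fix, but the surrounding logic is worth a note: before saving, each `LoraLinear` is switched to merge mode and put into eval so that its low-rank update is folded into the base weight, and the saved state_dict then looks like a plain linear layer. A much-simplified sketch of that folding idea; this `TinyLoraLinear` is hypothetical, ignores scaling and dropout, and is not Coati's `LoraLinear`.

```python
import torch
import torch.nn as nn


class TinyLoraLinear(nn.Linear):
    """Hypothetical, stripped-down LoRA linear used only for illustration."""

    def __init__(self, in_features: int, out_features: int, r: int = 4):
        super().__init__(in_features, out_features)
        # low-rank update: delta_W = B @ A  (scaling and dropout omitted)
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.merged = False

    def eval(self):
        # analogous to `merge_weights = True` followed by `module.eval()`:
        # fold the low-rank update into the base weight before export
        if not self.merged:
            self.weight.data += self.lora_B @ self.lora_A
            self.merged = True
        return super().eval()


layer = TinyLoraLinear(16, 16)
layer.eval()                    # merge happens here
state = layer.state_dict()      # 'weight' now carries the merged update
```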