# mirror of https://github.com/hpcaitech/ColossalAI
# 19 lines · 580 B · Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze every parameter except the embeddings and the LM head.

    Parameters whose qualified name contains ``embed_tokens`` or ``lm_head``
    keep ``requires_grad = True``; all other parameters are frozen
    (``requires_grad = False``). Useful for embedding-only fine-tuning,
    e.g. after extending the tokenizer vocabulary.
    """
    for param_name, param in model.named_parameters():
        # A parameter is trainable iff it belongs to the input embeddings
        # or the output projection.
        is_embed = ("embed_tokens" in param_name) or ("lm_head" in param_name)
        param.requires_grad = is_embed
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Enable gradient computation for every parameter of ``model``.

    Bug fix: the previous body set ``requires_grad = False``, which froze
    the entire model — the opposite of what the function name promises.
    The unused ``name`` loop variable is discarded with ``_``.
    """
    for _, params in model.named_parameters():
        params.requires_grad = True