mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
143 lines
5.3 KiB
143 lines
5.3 KiB
2 years ago
|
import os
|
||
|
from collections import OrderedDict
|
||
1 year ago
|
from typing import Any, Dict
|
||
2 years ago
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
import torch.nn as nn
|
||
|
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||
|
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||
|
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||
|
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||
1 year ago
|
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||
1 year ago
|
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
|
||
2 years ago
|
|
||
|
|
||
|
def is_rank_0() -> bool:
    """Return True if this process is rank 0.

    When torch.distributed has not been initialized, the process is treated
    as rank 0, so single-process runs always report True.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
|
||
|
|
||
|
|
||
|
def get_rank() -> int:
    """Return this process's rank, falling back to 0 when torch.distributed is not initialized."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
|
||
|
|
||
|
|
||
|
def get_world_size() -> int:
    """Return the number of participating processes, or 1 when torch.distributed is not initialized."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
|
||
|
|
||
|
|
||
|
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
    """Construct the actor network for a model family.

    Args:
        model: model family name; one of "gpt2", "bloom", "opt", "llama".
        pretrained: forwarded unchanged to the actor constructor.
        config: forwarded unchanged to the actor constructor.
        lora_rank: forwarded unchanged to the actor constructor.

    Returns:
        The newly constructed actor instance.

    Raises:
        ValueError: if ``model`` is not one of the supported families.
    """
    supported = ("gpt2", "bloom", "opt", "llama")
    if model not in supported:
        raise ValueError(f'Unsupported actor model "{model}"')

    # Validation above guarantees this lookup succeeds.
    actor_cls = {
        "gpt2": GPTActor,
        "bloom": BLOOMActor,
        "opt": OPTActor,
        "llama": LlamaActor,
    }[model]
    return actor_cls(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||
|
|
||
|
|
||
|
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank: int = 0):
    """Construct the critic network for a model family.

    Args:
        model: model family name; one of "gpt2", "bloom", "opt", "llama".
        pretrained: forwarded unchanged to the critic constructor.
        config: forwarded unchanged to the critic constructor.
        lora_rank: forwarded unchanged to the critic constructor.

    Returns:
        The newly constructed critic instance.

    Raises:
        ValueError: if ``model`` is not one of the supported families.
    """
    if model == "gpt2":
        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "bloom":
        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "opt":
        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "llama":
        critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    else:
        # Name the factory that actually failed so the error points at the critic argument.
        raise ValueError(f'Unsupported critic model "{model}"')
    return critic
|
||
|
|
||
|
|
||
|
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
    """Construct the reward model for a model family.

    Args:
        model: model family name; one of "gpt2", "bloom", "opt", "llama".
        pretrained: forwarded unchanged to the reward-model constructor.
        config: forwarded unchanged to the reward-model constructor.

    Returns:
        The newly constructed reward model instance.

    Raises:
        ValueError: if ``model`` is not one of the supported families.
    """
    if model == "gpt2":
        return GPTRM(pretrained=pretrained, config=config)
    if model == "bloom":
        return BLOOMRM(pretrained=pretrained, config=config)
    if model == "opt":
        return OPTRM(pretrained=pretrained, config=config)
    if model == "llama":
        return LlamaRM(pretrained=pretrained, config=config)
    raise ValueError(f'Unsupported reward model "{model}"')
|
||
|
|
||
|
|
||
|
def get_strategy_from_args(strategy: str):
    """Create a training strategy from its command-line name.

    Args:
        strategy: one of "ddp", "colossalai_gemini", "colossalai_gemini_cpu",
            "colossalai_zero2", "colossalai_zero2_cpu".

    Returns:
        A newly constructed strategy object.

    Raises:
        ValueError: if ``strategy`` is not a recognized name.
    """
    if strategy == "ddp":
        return DDPStrategy()
    if strategy == "colossalai_gemini":
        return GeminiStrategy(placement_policy="static", initial_scale=2**5)
    if strategy == "colossalai_gemini_cpu":
        # Same as "colossalai_gemini" plus offload fractions of 1.0; presumably this
        # offloads all optimizer states and parameters to CPU — confirm with GeminiStrategy docs.
        return GeminiStrategy(
            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
        )
    if strategy == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    if strategy == "colossalai_zero2_cpu":
        return LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    raise ValueError(f'Unsupported strategy "{strategy}"')
|
||
|
|
||
|
|
||
|
def get_tokenizer_from_args(model: str, **kwargs):
    """Load the tokenizer for a model family and set its pad token to its EOS token.

    "gpt2", "bloom" and "opt" load fixed Hugging Face checkpoints; "llama" loads
    from the path passed as the ``pretrain`` keyword argument.

    Args:
        model: model family name; one of "gpt2", "bloom", "opt", "llama".
        **kwargs: only ``pretrain`` is read, and only when ``model == "llama"``.

    Returns:
        The loaded tokenizer, with ``pad_token`` set to ``eos_token``.

    Raises:
        ValueError: if ``model`` is not one of the supported families.
        KeyError: if ``model == "llama"`` and no ``pretrain`` keyword was given.
    """
    if model not in ("gpt2", "bloom", "opt", "llama"):
        raise ValueError(f'Unsupported model "{model}"')

    if model == "llama":
        tokenizer = AutoTokenizer.from_pretrained(kwargs["pretrain"])
    else:
        loader, checkpoint = {
            "gpt2": (GPT2Tokenizer, "gpt2"),
            "bloom": (BloomTokenizerFast, "bigscience/bloom-560m"),
            "opt": (AutoTokenizer, "facebook/opt-350m"),
        }[model]
        tokenizer = loader.from_pretrained(checkpoint)

    # Reuse EOS for padding. Presumably some of these tokenizers ship without a pad token — verify.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
|
||
|
|
||
|
|
||
|
def set_dist_env(env_info: Dict[str, str]):
    """Export distributed-launch settings from ``env_info`` into ``os.environ``.

    Reads the keys "rank", "local_rank", "world_size", "master_port" and
    "master_addr" and writes them to RANK, LOCAL_RANK, WORLD_SIZE, MASTER_PORT
    and MASTER_ADDR, in that order. A missing key raises KeyError after the
    preceding variables have already been written.
    """
    env_to_info_key = (
        ("RANK", "rank"),
        ("LOCAL_RANK", "local_rank"),
        ("WORLD_SIZE", "world_size"),
        ("MASTER_PORT", "master_port"),
        ("MASTER_ADDR", "master_addr"),
    )
    for env_var, info_key in env_to_info_key:
        os.environ[env_var] = env_info[info_key]
|
||
2 years ago
|
|
||
|
|
||
|
def get_model_numel(model: nn.Module) -> int:
    """Return the total number of elements across all tensors yielded by ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||
|
|
||
|
|
||
|
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
    """Return the receiver indices that sender ``sender_idx`` should send to.

    Fan-out (``num_senders <= num_receivers`` or ``allow_idle_sender``): receiver
    ``i`` is assigned to sender ``i % num_senders``. Each receiver has exactly one
    sender, and a sender may get several receivers or none.

    Fan-in (otherwise): the sender targets only ``sender_idx % num_receivers``,
    so a receiver may have more than one sender.
    """
    if not (num_senders <= num_receivers or allow_idle_sender):
        # Fan-in: one receiver per sender; receivers may be shared.
        return [sender_idx % num_receivers]
    # Fan-out: round-robin receivers over senders.
    return [receiver for receiver in range(num_receivers) if receiver % num_senders == sender_idx]
|
||
|
|
||
|
|
||
1 year ago
|
def state_dict_to(
    state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
    """Return a new OrderedDict with every value converted via ``.to(dtype=dtype, device=device)``.

    The input mapping itself is not modified, and key order is preserved. NOTE:
    per the PyTorch docs, ``Tensor.to`` may return the same tensor when no
    conversion is needed, so such entries can alias the input tensors.
    """
    return OrderedDict((name, tensor.to(dtype=dtype, device=device)) for name, tensor in state_dict.items())
|