mirror of https://github.com/hpcaitech/ColossalAI
change nn to models (#3032)
parent 4269196c79
commit c21b11edce
@@ -41,7 +41,8 @@ Simplest usage:
 ```python
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.base import RewardModel
 from copy import deepcopy
 from colossalai.nn.optimizer import HybridAdam

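For clarity, the README's import block after this hunk reads as follows; only the `chatgpt.nn` line is replaced, the other imports in the snippet are unchanged:

```python
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.base import RewardModel
from copy import deepcopy
from colossalai.nn.optimizer import HybridAdam
```
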
@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

@@ -4,7 +4,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from chatgpt.nn.actor import Actor
+from chatgpt.models.base import Actor
 
 
 @dataclass

@@ -1,5 +1,5 @@
 import torch
-from chatgpt.nn.utils import compute_reward, normalize
+from chatgpt.models.utils import compute_reward, normalize
 
 from .base import Experience, ExperienceMaker

@@ -0,0 +1,4 @@
+from .base import Actor, Critic, RewardModel
+from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+
+__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']

@@ -0,0 +1,5 @@
+from .actor import Actor
+from .critic import Critic
+from .reward_model import RewardModel
+
+__all__ = ['Actor', 'Critic', 'RewardModel']

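A note on the two new `__init__.py` files above: the top-level `chatgpt/models/__init__.py` re-exports the generic classes and the losses, so architecture-agnostic code can import from the package root instead of the `base`/`loss` submodules. A small sketch using only names defined in the hunks of this commit (the equivalence is an assumption based on these re-exports):

```python
# Both forms resolve to the same classes under the new layout, since
# chatgpt/models/__init__.py re-exports Actor/Critic/RewardModel from
# .base and the loss classes from .loss:
from chatgpt.models import Actor, Critic, RewardModel, PolicyLoss, ValueLoss

# ...which is the same set of names as:
#   from chatgpt.models.base import Actor, Critic, RewardModel
#   from chatgpt.models.loss import PolicyLoss, ValueLoss
```
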
@@ -4,9 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .generation import generate
-from .lora import LoRAModule
-from .utils import log_probs_from_logits
+from ..generation import generate
+from ..lora import LoRAModule
+from ..utils import log_probs_from_logits
 
 
 class Actor(LoRAModule):

@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from .lora import LoRAModule
-from .utils import masked_mean
+from ..lora import LoRAModule
+from ..utils import masked_mean
 
 
 class Critic(LoRAModule):

@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from .lora import LoRAModule
+from ..lora import LoRAModule
 
 
 class RewardModel(LoRAModule):

@@ -0,0 +1,5 @@
+from .bloom_actor import BLOOMActor
+from .bloom_critic import BLOOMCritic
+from .bloom_rm import BLOOMRM
+
+__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']

@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
 
-from .actor import Actor
+from ..base import Actor
 
 
 class BLOOMActor(Actor):

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
 
-from .critic import Critic
+from ..base import Critic
 
 
 class BLOOMCritic(Critic):

@@ -3,7 +3,7 @@ from typing import Optional
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
 
-from .reward_model import RewardModel
+from ..base import RewardModel
 
 
 class BLOOMRM(RewardModel):

@@ -0,0 +1,5 @@
+from .gpt_actor import GPTActor
+from .gpt_critic import GPTCritic
+from .gpt_rm import GPTRM
+
+__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']

@@ -3,7 +3,7 @@ from typing import Optional
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 
-from .actor import Actor
+from ..base import Actor
 
 
 class GPTActor(Actor):

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 
-from .critic import Critic
+from ..base import Critic
 
 
 class GPTCritic(Critic):

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 
-from .reward_model import RewardModel
+from ..base import RewardModel
 
 
 class GPTRM(RewardModel):

@@ -0,0 +1,5 @@
+from .opt_actor import OPTActor
+from .opt_critic import OPTCritic
+from .opt_rm import OPTRM
+
+__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']

@@ -3,7 +3,7 @@ from typing import Optional
 from transformers.models.opt.configuration_opt import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM
 
-from .actor import Actor
+from ..base import Actor
 
 
 class OPTActor(Actor):

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.opt.configuration_opt import OPTConfig
 from transformers.models.opt.modeling_opt import OPTModel
 
-from .critic import Critic
+from ..base import Critic
 
 
 class OPTCritic(Critic):

@@ -3,7 +3,7 @@ from typing import Optional
 import torch.nn as nn
 from transformers import OPTConfig, OPTModel
 
-from .reward_model import RewardModel
+from ..base import RewardModel
 
 
 class OPTRM(RewardModel):

@@ -1,18 +0,0 @@
-from .actor import Actor
-from .bloom_actor import BLOOMActor
-from .bloom_critic import BLOOMCritic
-from .bloom_rm import BLOOMRM
-from .critic import Critic
-from .gpt_actor import GPTActor
-from .gpt_critic import GPTCritic
-from .gpt_rm import GPTRM
-from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
-from .opt_actor import OPTActor
-from .opt_critic import OPTCritic
-from .opt_rm import OPTRM
-from .reward_model import RewardModel
-
-__all__ = [
-    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor',
-    'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM'
-]

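The flat `chatgpt/nn/__init__.py` deleted above is the other half of the rename. As a migration sketch, every name it used to export now lives in one of the new subpackages; the class and module names below are taken directly from the hunks in this commit:

```python
# Old flat layout (removed by this commit):
#   from chatgpt.nn import Actor, Critic, RewardModel, PolicyLoss, ValueLoss, \
#       PPOPtxActorLoss, PairWiseLoss, GPTActor, GPTCritic, GPTRM, \
#       BLOOMActor, BLOOMCritic, BLOOMRM, OPTActor, OPTCritic, OPTRM

# New package layout:
from chatgpt.models.base import Actor, Critic, RewardModel
from chatgpt.models.loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from chatgpt.models.gpt import GPTActor, GPTCritic, GPTRM
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic, BLOOMRM
from chatgpt.models.opt import OPTActor, OPTCritic, OPTRM
```
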
@@ -2,8 +2,9 @@ from typing import Any, Callable, Dict, List, Optional
 
 import torch.nn as nn
 from chatgpt.experience_maker import Experience, NaiveExperienceMaker
-from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
-from chatgpt.nn.generation_utils import update_model_kwargs_fn
+from chatgpt.models.base import Actor, Critic
+from chatgpt.models.generation_utils import update_model_kwargs_fn
+from chatgpt.models.loss import PolicyLoss, ValueLoss
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from torch.optim import Optimizer

@@ -3,7 +3,7 @@ from abc import ABC
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import PairWiseLoss
+from chatgpt.models.loss import PairWiseLoss
 from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm

@@ -5,7 +5,7 @@ from typing import Any, List, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor
 from torch.optim import Optimizer
 
 import colossalai

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

@@ -1,7 +1,9 @@
 import argparse
 
 import torch
-from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
+from chatgpt.models.bloom import BLOOMActor
+from chatgpt.models.gpt import GPTActor
+from chatgpt.models.opt import OPTActor
 from transformers import AutoTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

@@ -2,7 +2,10 @@ import argparse
 from copy import deepcopy
 
 import torch
-from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy

@@ -3,7 +3,10 @@ from copy import deepcopy
 
 import pandas as pd
 import torch
-from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam

@@ -3,7 +3,10 @@ import argparse
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMRM
+from chatgpt.models.gpt import GPTRM
+from chatgpt.models.opt import OPTRM
 from chatgpt.trainer import RewardModelTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset

@@ -7,7 +7,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from chatgpt.nn import GPTActor
+from chatgpt.models.gpt import GPTActor
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

@@ -7,7 +7,8 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from chatgpt.experience_maker import NaiveExperienceMaker
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config