change nn to models (#3032)

pull/3037/head
Fazzie-Maqianli 2023-03-07 16:34:22 +08:00 committed by GitHub
parent 4269196c79
commit c21b11edce
39 changed files with 72 additions and 50 deletions
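The commit renames the flat `chatgpt.nn` package to `chatgpt.models` and splits it into subpackages (`base`, `loss`, `gpt`, `opt`, `bloom`, plus shared `lora`, `utils`, and generation helpers), so every import site changes accordingly. A minimal before/after sketch of the import change, using only names that appear in the diffs below:

```python
# Before this commit: actors, critics, reward models and losses all lived in one flat package.
from chatgpt.nn import GPTActor, GPTCritic, RewardModel, PairWiseLoss

# After this commit: shared bases and losses live in chatgpt.models.base / chatgpt.models.loss,
# while model-family classes move to chatgpt.models.gpt / .opt / .bloom.
from chatgpt.models.base import RewardModel
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.loss import PairWiseLoss
```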

View File

@@ -41,7 +41,8 @@ Simplest usage:
 ```python
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.base import RewardModel
 from copy import deepcopy
 from colossalai.nn.optimizer import HybridAdam

View File

@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

View File

@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

View File

@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from chatgpt.nn.actor import Actor
+from chatgpt.models.base import Actor
 @dataclass

View File

@@ -1,5 +1,5 @@
 import torch
-from chatgpt.nn.utils import compute_reward, normalize
+from chatgpt.models.utils import compute_reward, normalize
 from .base import Experience, ExperienceMaker

View File

@@ -0,0 +1,4 @@
+from .base import Actor, Critic, RewardModel
+from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']

View File

@@ -0,0 +1,5 @@
+from .actor import Actor
+from .critic import Critic
+from .reward_model import RewardModel
+__all__ = ['Actor', 'Critic', 'RewardModel']

View File

@@ -4,9 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .generation import generate
-from .lora import LoRAModule
-from .utils import log_probs_from_logits
+from ..generation import generate
+from ..lora import LoRAModule
+from ..utils import log_probs_from_logits
 class Actor(LoRAModule):

View File

@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from .lora import LoRAModule
-from .utils import masked_mean
+from ..lora import LoRAModule
+from ..utils import masked_mean
 class Critic(LoRAModule):

View File

@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from .lora import LoRAModule
+from ..lora import LoRAModule
 class RewardModel(LoRAModule):

View File

@@ -0,0 +1,5 @@
+from .bloom_actor import BLOOMActor
+from .bloom_critic import BLOOMCritic
+from .bloom_rm import BLOOMRM
+__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']

View File

@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .actor import Actor
+from ..base import Actor
 class BLOOMActor(Actor):

View File

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .critic import Critic
+from ..base import Critic
 class BLOOMCritic(Critic):

View File

@@ -3,7 +3,7 @@ from typing import Optional
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .reward_model import RewardModel
+from ..base import RewardModel
 class BLOOMRM(RewardModel):

View File

@@ -0,0 +1,5 @@
+from .gpt_actor import GPTActor
+from .gpt_critic import GPTCritic
+from .gpt_rm import GPTRM
+__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']

View File

@@ -3,7 +3,7 @@ from typing import Optional
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from .actor import Actor
+from ..base import Actor
 class GPTActor(Actor):

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
-from .critic import Critic
+from ..base import Critic
 class GPTCritic(Critic):

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
-from .reward_model import RewardModel
+from ..base import RewardModel
 class GPTRM(RewardModel):

View File

@@ -0,0 +1,5 @@
+from .opt_actor import OPTActor
+from .opt_critic import OPTCritic
+from .opt_rm import OPTRM
+__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']

View File

@@ -3,7 +3,7 @@ from typing import Optional
 from transformers.models.opt.configuration_opt import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM
-from .actor import Actor
+from ..base import Actor
 class OPTActor(Actor):

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.opt.configuration_opt import OPTConfig
 from transformers.models.opt.modeling_opt import OPTModel
-from .critic import Critic
+from ..base import Critic
 class OPTCritic(Critic):

View File

@@ -3,7 +3,7 @@ from typing import Optional
 import torch.nn as nn
 from transformers import OPTConfig, OPTModel
-from .reward_model import RewardModel
+from ..base import RewardModel
 class OPTRM(RewardModel):

View File

@@ -1,18 +0,0 @@
-from .actor import Actor
-from .bloom_actor import BLOOMActor
-from .bloom_critic import BLOOMCritic
-from .bloom_rm import BLOOMRM
-from .critic import Critic
-from .gpt_actor import GPTActor
-from .gpt_critic import GPTCritic
-from .gpt_rm import GPTRM
-from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
-from .opt_actor import OPTActor
-from .opt_critic import OPTCritic
-from .opt_rm import OPTRM
-from .reward_model import RewardModel
-__all__ = [
-    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor',
-    'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM'
-]

View File

@@ -2,8 +2,9 @@ from typing import Any, Callable, Dict, List, Optional
 import torch.nn as nn
 from chatgpt.experience_maker import Experience, NaiveExperienceMaker
-from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
-from chatgpt.nn.generation_utils import update_model_kwargs_fn
+from chatgpt.models.base import Actor, Critic
+from chatgpt.models.generation_utils import update_model_kwargs_fn
+from chatgpt.models.loss import PolicyLoss, ValueLoss
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from torch.optim import Optimizer

View File

@@ -3,7 +3,7 @@ from abc import ABC
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import PairWiseLoss
+from chatgpt.models.loss import PairWiseLoss
 from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm

View File

@@ -5,7 +5,7 @@ from typing import Any, List, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor, Critic, RewardModel
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

View File

@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor
 from torch.optim import Optimizer
 import colossalai

View File

@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import Actor
+from chatgpt.models.base import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

View File

@@ -1,7 +1,9 @@
 import argparse
 import torch
-from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
+from chatgpt.models.bloom import BLOOMActor
+from chatgpt.models.gpt import GPTActor
+from chatgpt.models.opt import OPTActor
 from transformers import AutoTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

View File

@@ -2,7 +2,10 @@ import argparse
 from copy import deepcopy
 import torch
-from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import SaveCheckpoint
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy

View File

@@ -3,7 +3,10 @@ from copy import deepcopy
 import pandas as pd
 import torch
-from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam

View File

@@ -3,7 +3,10 @@ import argparse
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMRM
+from chatgpt.models.gpt import GPTRM
+from chatgpt.models.opt import OPTRM
 from chatgpt.trainer import RewardModelTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset

View File

@@ -7,7 +7,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from chatgpt.nn import GPTActor
+from chatgpt.models.gpt import GPTActor
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

View File

@@ -7,7 +7,8 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from chatgpt.experience_maker import NaiveExperienceMaker
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config