From c21b11edce3b772cdbcb4e5fafe95f62ac49af94 Mon Sep 17 00:00:00 2001 From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com> Date: Tue, 7 Mar 2023 16:34:22 +0800 Subject: [PATCH] change nn to models (#3032) --- applications/ChatGPT/README.md | 3 ++- .../ChatGPT/benchmarks/benchmark_gpt_dummy.py | 3 ++- .../benchmarks/benchmark_opt_lora_dummy.py | 3 ++- .../ChatGPT/chatgpt/experience_maker/base.py | 2 +- .../ChatGPT/chatgpt/experience_maker/naive.py | 2 +- .../ChatGPT/chatgpt/models/__init__.py | 4 ++++ .../ChatGPT/chatgpt/models/base/__init__.py | 5 +++++ .../chatgpt/{nn => models/base}/actor.py | 6 +++--- .../chatgpt/{nn => models/base}/critic.py | 4 ++-- .../{nn => models/base}/reward_model.py | 2 +- .../ChatGPT/chatgpt/models/bloom/__init__.py | 5 +++++ .../{nn => models/bloom}/bloom_actor.py | 2 +- .../{nn => models/bloom}/bloom_critic.py | 2 +- .../chatgpt/{nn => models/bloom}/bloom_rm.py | 2 +- .../chatgpt/{nn => models}/generation.py | 0 .../chatgpt/{nn => models}/generation_utils.py | 0 .../ChatGPT/chatgpt/models/gpt/__init__.py | 5 +++++ .../chatgpt/{nn => models/gpt}/gpt_actor.py | 2 +- .../chatgpt/{nn => models/gpt}/gpt_critic.py | 2 +- .../chatgpt/{nn => models/gpt}/gpt_rm.py | 2 +- .../ChatGPT/chatgpt/{nn => models}/lora.py | 0 .../ChatGPT/chatgpt/{nn => models}/loss.py | 0 .../ChatGPT/chatgpt/models/opt/__init__.py | 5 +++++ .../chatgpt/{nn => models/opt}/opt_actor.py | 2 +- .../chatgpt/{nn => models/opt}/opt_critic.py | 2 +- .../chatgpt/{nn => models/opt}/opt_rm.py | 2 +- .../ChatGPT/chatgpt/{nn => models}/utils.py | 0 applications/ChatGPT/chatgpt/nn/__init__.py | 18 ------------------ applications/ChatGPT/chatgpt/trainer/ppo.py | 5 +++-- applications/ChatGPT/chatgpt/trainer/rm.py | 2 +- .../ChatGPT/chatgpt/trainer/strategies/base.py | 2 +- .../chatgpt/trainer/strategies/colossalai.py | 2 +- .../ChatGPT/chatgpt/trainer/strategies/ddp.py | 2 +- applications/ChatGPT/examples/inference.py | 4 +++- applications/ChatGPT/examples/train_dummy.py | 5 ++++- applications/ChatGPT/examples/train_prompts.py | 5 ++++- .../ChatGPT/examples/train_reward_model.py | 5 ++++- applications/ChatGPT/tests/test_checkpoint.py | 2 +- applications/ChatGPT/tests/test_data.py | 3 ++- 39 files changed, 72 insertions(+), 50 deletions(-) create mode 100644 applications/ChatGPT/chatgpt/models/__init__.py create mode 100644 applications/ChatGPT/chatgpt/models/base/__init__.py rename applications/ChatGPT/chatgpt/{nn => models/base}/actor.py (95%) rename applications/ChatGPT/chatgpt/{nn => models/base}/critic.py (95%) rename applications/ChatGPT/chatgpt/{nn => models/base}/reward_model.py (97%) create mode 100644 applications/ChatGPT/chatgpt/models/bloom/__init__.py rename applications/ChatGPT/chatgpt/{nn => models/bloom}/bloom_actor.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/bloom}/bloom_critic.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/bloom}/bloom_rm.py (96%) rename applications/ChatGPT/chatgpt/{nn => models}/generation.py (100%) rename applications/ChatGPT/chatgpt/{nn => models}/generation_utils.py (100%) create mode 100644 applications/ChatGPT/chatgpt/models/gpt/__init__.py rename applications/ChatGPT/chatgpt/{nn => models/gpt}/gpt_actor.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/gpt}/gpt_critic.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/gpt}/gpt_rm.py (96%) rename applications/ChatGPT/chatgpt/{nn => models}/lora.py (100%) rename applications/ChatGPT/chatgpt/{nn => models}/loss.py (100%) create mode 100644 applications/ChatGPT/chatgpt/models/opt/__init__.py rename applications/ChatGPT/chatgpt/{nn => models/opt}/opt_actor.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/opt}/opt_critic.py (97%) rename applications/ChatGPT/chatgpt/{nn => models/opt}/opt_rm.py (96%) rename applications/ChatGPT/chatgpt/{nn => models}/utils.py (100%) delete mode 100644 applications/ChatGPT/chatgpt/nn/__init__.py diff --git a/applications/ChatGPT/README.md b/applications/ChatGPT/README.md index d26206144..23c6aa372 100644 --- a/applications/ChatGPT/README.md +++ b/applications/ChatGPT/README.md @@ -41,7 +41,8 @@ Simplest usage: ```python from chatgpt.trainer import PPOTrainer from chatgpt.trainer.strategies import ColossalAIStrategy -from chatgpt.nn import GPTActor, GPTCritic, RewardModel +from chatgpt.models.gpt import GPTActor, GPTCritic +from chatgpt.models.base import RewardModel from copy import deepcopy from colossalai.nn.optimizer import HybridAdam diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py index b5730c7c7..5ee65763b 100644 --- a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py +++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py @@ -4,7 +4,8 @@ from copy import deepcopy import torch import torch.distributed as dist import torch.nn as nn -from chatgpt.nn import GPTActor, GPTCritic, RewardModel +from chatgpt.models.base import RewardModel +from chatgpt.models.gpt import GPTActor, GPTCritic from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py index 6777cb770..207edbca9 100644 --- a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py @@ -4,7 +4,8 @@ from copy import deepcopy import torch import torch.distributed as dist import torch.nn as nn -from chatgpt.nn import OPTActor, OPTCritic, RewardModel +from chatgpt.models.base import RewardModel +from chatgpt.models.opt import OPTActor, OPTCritic from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import PerformanceEvaluator from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy diff --git a/applications/ChatGPT/chatgpt/experience_maker/base.py b/applications/ChatGPT/chatgpt/experience_maker/base.py index 61895322c..f3640fc1e 100644 --- a/applications/ChatGPT/chatgpt/experience_maker/base.py +++ b/applications/ChatGPT/chatgpt/experience_maker/base.py @@ -4,7 +4,7 @@ from typing import Optional import torch import torch.nn as nn -from chatgpt.nn.actor import Actor +from chatgpt.models.base import Actor @dataclass diff --git a/applications/ChatGPT/chatgpt/experience_maker/naive.py b/applications/ChatGPT/chatgpt/experience_maker/naive.py index f4fd2078c..64835cfa1 100644 --- a/applications/ChatGPT/chatgpt/experience_maker/naive.py +++ b/applications/ChatGPT/chatgpt/experience_maker/naive.py @@ -1,5 +1,5 @@ import torch -from chatgpt.nn.utils import compute_reward, normalize +from chatgpt.models.utils import compute_reward, normalize from .base import Experience, ExperienceMaker diff --git a/applications/ChatGPT/chatgpt/models/__init__.py b/applications/ChatGPT/chatgpt/models/__init__.py new file mode 100644 index 000000000..376fed8de --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/__init__.py @@ -0,0 +1,4 @@ +from .base import Actor, Critic, RewardModel +from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss + +__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss'] diff --git a/applications/ChatGPT/chatgpt/models/base/__init__.py b/applications/ChatGPT/chatgpt/models/base/__init__.py new file mode 100644 index 000000000..86f403556 --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/base/__init__.py @@ -0,0 +1,5 @@ +from .actor import Actor +from .critic import Critic +from .reward_model import RewardModel + +__all__ = ['Actor', 'Critic', 'RewardModel'] diff --git a/applications/ChatGPT/chatgpt/nn/actor.py b/applications/ChatGPT/chatgpt/models/base/actor.py similarity index 95% rename from applications/ChatGPT/chatgpt/nn/actor.py rename to applications/ChatGPT/chatgpt/models/base/actor.py index c4c0d579d..e2841dc68 100644 --- a/applications/ChatGPT/chatgpt/nn/actor.py +++ b/applications/ChatGPT/chatgpt/models/base/actor.py @@ -4,9 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .generation import generate -from .lora import LoRAModule -from .utils import log_probs_from_logits +from ..generation import generate +from ..lora import LoRAModule +from ..utils import log_probs_from_logits class Actor(LoRAModule): diff --git a/applications/ChatGPT/chatgpt/nn/critic.py b/applications/ChatGPT/chatgpt/models/base/critic.py similarity index 95% rename from applications/ChatGPT/chatgpt/nn/critic.py rename to applications/ChatGPT/chatgpt/models/base/critic.py index f3a123854..4bff5ee97 100644 --- a/applications/ChatGPT/chatgpt/nn/critic.py +++ b/applications/ChatGPT/chatgpt/models/base/critic.py @@ -3,8 +3,8 @@ from typing import Optional import torch import torch.nn as nn -from .lora import LoRAModule -from .utils import masked_mean +from ..lora import LoRAModule +from ..utils import masked_mean class Critic(LoRAModule): diff --git a/applications/ChatGPT/chatgpt/nn/reward_model.py b/applications/ChatGPT/chatgpt/models/base/reward_model.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/reward_model.py rename to applications/ChatGPT/chatgpt/models/base/reward_model.py index 27cd1ccae..ce8c0a1d3 100644 --- a/applications/ChatGPT/chatgpt/nn/reward_model.py +++ b/applications/ChatGPT/chatgpt/models/base/reward_model.py @@ -3,7 +3,7 @@ from typing import Optional import torch import torch.nn as nn -from .lora import LoRAModule +from ..lora import LoRAModule class RewardModel(LoRAModule): diff --git a/applications/ChatGPT/chatgpt/models/bloom/__init__.py b/applications/ChatGPT/chatgpt/models/bloom/__init__.py new file mode 100644 index 000000000..d0e7f7b1e --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/bloom/__init__.py @@ -0,0 +1,5 @@ +from .bloom_actor import BLOOMActor +from .bloom_critic import BLOOMCritic +from .bloom_rm import BLOOMRM + +__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM'] diff --git a/applications/ChatGPT/chatgpt/nn/bloom_actor.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/bloom_actor.py rename to applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py index 103536bc3..d7577f096 100644 --- a/applications/ChatGPT/chatgpt/nn/bloom_actor.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py @@ -3,7 +3,7 @@ from typing import Optional import torch from transformers import BloomConfig, BloomForCausalLM, BloomModel -from .actor import Actor +from ..base import Actor class BLOOMActor(Actor): diff --git a/applications/ChatGPT/chatgpt/nn/bloom_critic.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/bloom_critic.py rename to applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py index 3b03471a3..5a907309a 100644 --- a/applications/ChatGPT/chatgpt/nn/bloom_critic.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from transformers import BloomConfig, BloomForCausalLM, BloomModel -from .critic import Critic +from ..base import Critic class BLOOMCritic(Critic): diff --git a/applications/ChatGPT/chatgpt/nn/bloom_rm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py similarity index 96% rename from applications/ChatGPT/chatgpt/nn/bloom_rm.py rename to applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py index 12c37957d..4dc2646e3 100644 --- a/applications/ChatGPT/chatgpt/nn/bloom_rm.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py @@ -3,7 +3,7 @@ from typing import Optional import torch.nn as nn from transformers import BloomConfig, BloomForCausalLM, BloomModel -from .reward_model import RewardModel +from ..base import RewardModel class BLOOMRM(RewardModel): diff --git a/applications/ChatGPT/chatgpt/nn/generation.py b/applications/ChatGPT/chatgpt/models/generation.py similarity index 100% rename from applications/ChatGPT/chatgpt/nn/generation.py rename to applications/ChatGPT/chatgpt/models/generation.py diff --git a/applications/ChatGPT/chatgpt/nn/generation_utils.py b/applications/ChatGPT/chatgpt/models/generation_utils.py similarity index 100% rename from applications/ChatGPT/chatgpt/nn/generation_utils.py rename to applications/ChatGPT/chatgpt/models/generation_utils.py diff --git a/applications/ChatGPT/chatgpt/models/gpt/__init__.py b/applications/ChatGPT/chatgpt/models/gpt/__init__.py new file mode 100644 index 000000000..63dc5ab0f --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/gpt/__init__.py @@ -0,0 +1,5 @@ +from .gpt_actor import GPTActor +from .gpt_critic import GPTCritic +from .gpt_rm import GPTRM + +__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] diff --git a/applications/ChatGPT/chatgpt/nn/gpt_actor.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/gpt_actor.py rename to applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py index 491182ffa..da24685e1 100644 --- a/applications/ChatGPT/chatgpt/nn/gpt_actor.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py @@ -3,7 +3,7 @@ from typing import Optional from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel -from .actor import Actor +from ..base import Actor class GPTActor(Actor): diff --git a/applications/ChatGPT/chatgpt/nn/gpt_critic.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/gpt_critic.py rename to applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py index b0a001f4a..897ddb4ae 100644 --- a/applications/ChatGPT/chatgpt/nn/gpt_critic.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2Model -from .critic import Critic +from ..base import Critic class GPTCritic(Critic): diff --git a/applications/ChatGPT/chatgpt/nn/gpt_rm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py similarity index 96% rename from applications/ChatGPT/chatgpt/nn/gpt_rm.py rename to applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py index fcfb61cd4..0132dbf27 100644 --- a/applications/ChatGPT/chatgpt/nn/gpt_rm.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2Model -from .reward_model import RewardModel +from ..base import RewardModel class GPTRM(RewardModel): diff --git a/applications/ChatGPT/chatgpt/nn/lora.py b/applications/ChatGPT/chatgpt/models/lora.py similarity index 100% rename from applications/ChatGPT/chatgpt/nn/lora.py rename to applications/ChatGPT/chatgpt/models/lora.py diff --git a/applications/ChatGPT/chatgpt/nn/loss.py b/applications/ChatGPT/chatgpt/models/loss.py similarity index 100% rename from applications/ChatGPT/chatgpt/nn/loss.py rename to applications/ChatGPT/chatgpt/models/loss.py diff --git a/applications/ChatGPT/chatgpt/models/opt/__init__.py b/applications/ChatGPT/chatgpt/models/opt/__init__.py new file mode 100644 index 000000000..334f4df00 --- /dev/null +++ b/applications/ChatGPT/chatgpt/models/opt/__init__.py @@ -0,0 +1,5 @@ +from .opt_actor import OPTActor +from .opt_critic import OPTCritic +from .opt_rm import OPTRM + +__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] diff --git a/applications/ChatGPT/chatgpt/nn/opt_actor.py b/applications/ChatGPT/chatgpt/models/opt/opt_actor.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/opt_actor.py rename to applications/ChatGPT/chatgpt/models/opt/opt_actor.py index ff2bf7c00..c14e4377f 100644 --- a/applications/ChatGPT/chatgpt/nn/opt_actor.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_actor.py @@ -3,7 +3,7 @@ from typing import Optional from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.modeling_opt import OPTForCausalLM -from .actor import Actor +from ..base import Actor class OPTActor(Actor): diff --git a/applications/ChatGPT/chatgpt/nn/opt_critic.py b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py similarity index 97% rename from applications/ChatGPT/chatgpt/nn/opt_critic.py rename to applications/ChatGPT/chatgpt/models/opt/opt_critic.py index 9c9cb873f..767cecb79 100644 --- a/applications/ChatGPT/chatgpt/nn/opt_critic.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.modeling_opt import OPTModel -from .critic import Critic +from ..base import Critic class OPTCritic(Critic): diff --git a/applications/ChatGPT/chatgpt/nn/opt_rm.py b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py similarity index 96% rename from applications/ChatGPT/chatgpt/nn/opt_rm.py rename to applications/ChatGPT/chatgpt/models/opt/opt_rm.py index 5f518a3cc..7ad7b3887 100644 --- a/applications/ChatGPT/chatgpt/nn/opt_rm.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py @@ -3,7 +3,7 @@ from typing import Optional import torch.nn as nn from transformers import OPTConfig, OPTModel -from .reward_model import RewardModel +from ..base import RewardModel class OPTRM(RewardModel): diff --git a/applications/ChatGPT/chatgpt/nn/utils.py b/applications/ChatGPT/chatgpt/models/utils.py similarity index 100% rename from applications/ChatGPT/chatgpt/nn/utils.py rename to applications/ChatGPT/chatgpt/models/utils.py diff --git a/applications/ChatGPT/chatgpt/nn/__init__.py b/applications/ChatGPT/chatgpt/nn/__init__.py deleted file mode 100644 index c728d7df3..000000000 --- a/applications/ChatGPT/chatgpt/nn/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .actor import Actor -from .bloom_actor import BLOOMActor -from .bloom_critic import BLOOMCritic -from .bloom_rm import BLOOMRM -from .critic import Critic -from .gpt_actor import GPTActor -from .gpt_critic import GPTCritic -from .gpt_rm import GPTRM -from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss -from .opt_actor import OPTActor -from .opt_critic import OPTCritic -from .opt_rm import OPTRM -from .reward_model import RewardModel - -__all__ = [ - 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss', 'GPTActor', - 'GPTCritic', 'GPTRM', 'BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'OPTActor', 'OPTCritic', 'OPTRM' -] diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py index 2c1fd2fb6..789e0c2f8 100644 --- a/applications/ChatGPT/chatgpt/trainer/ppo.py +++ b/applications/ChatGPT/chatgpt/trainer/ppo.py @@ -2,8 +2,9 @@ from typing import Any, Callable, Dict, List, Optional import torch.nn as nn from chatgpt.experience_maker import Experience, NaiveExperienceMaker -from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss -from chatgpt.nn.generation_utils import update_model_kwargs_fn +from chatgpt.models.base import Actor, Critic +from chatgpt.models.generation_utils import update_model_kwargs_fn +from chatgpt.models.loss import PolicyLoss, ValueLoss from chatgpt.replay_buffer import NaiveReplayBuffer from torch.optim import Optimizer diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py index d44944aee..c07d65f84 100644 --- a/applications/ChatGPT/chatgpt/trainer/rm.py +++ b/applications/ChatGPT/chatgpt/trainer/rm.py @@ -3,7 +3,7 @@ from abc import ABC import loralib as lora import torch from chatgpt.dataset import RewardDataset -from chatgpt.nn import PairWiseLoss +from chatgpt.models.loss import PairWiseLoss from torch.optim import Adam, Optimizer from torch.utils.data import DataLoader from tqdm import tqdm diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/base.py b/applications/ChatGPT/chatgpt/trainer/strategies/base.py index 2a96078e9..4347c08b4 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/base.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/base.py @@ -5,7 +5,7 @@ from typing import Any, List, Tuple, Union import numpy as np import torch import torch.nn as nn -from chatgpt.nn import Actor +from chatgpt.models.base import Actor, Critic, RewardModel from chatgpt.replay_buffer import ReplayBuffer from torch.optim import Optimizer from torch.utils.data import DataLoader diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py index b6ed1d451..f08018fd2 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist import torch.nn as nn import torch.optim as optim -from chatgpt.nn import Actor +from chatgpt.models.base import Actor from torch.optim import Optimizer import colossalai diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py index 66e99dd39..530dd998d 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py @@ -5,7 +5,7 @@ import numpy as np import torch import torch.distributed as dist import torch.nn as nn -from chatgpt.nn import Actor +from chatgpt.models.base import Actor from chatgpt.replay_buffer import ReplayBuffer from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer diff --git a/applications/ChatGPT/examples/inference.py b/applications/ChatGPT/examples/inference.py index 239b6e19b..08885c33b 100644 --- a/applications/ChatGPT/examples/inference.py +++ b/applications/ChatGPT/examples/inference.py @@ -1,7 +1,9 @@ import argparse import torch -from chatgpt.nn import BLOOMActor, GPTActor, OPTActor +from chatgpt.models.bloom import BLOOMActor +from chatgpt.models.gpt import GPTActor +from chatgpt.models.opt import OPTActor from transformers import AutoTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py index df64515a1..27ee7f0f1 100644 --- a/applications/ChatGPT/examples/train_dummy.py +++ b/applications/ChatGPT/examples/train_dummy.py @@ -2,7 +2,10 @@ import argparse from copy import deepcopy import torch -from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel +from chatgpt.models.base import RewardModel +from chatgpt.models.bloom import BLOOMActor, BLOOMCritic +from chatgpt.models.gpt import GPTActor, GPTCritic +from chatgpt.models.opt import OPTActor, OPTCritic from chatgpt.trainer import PPOTrainer from chatgpt.trainer.callbacks import SaveCheckpoint from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py index db4c7d475..576685234 100644 --- a/applications/ChatGPT/examples/train_prompts.py +++ b/applications/ChatGPT/examples/train_prompts.py @@ -3,7 +3,10 @@ from copy import deepcopy import pandas as pd import torch -from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel +from chatgpt.models.base import RewardModel +from chatgpt.models.bloom import BLOOMActor, BLOOMCritic +from chatgpt.models.gpt import GPTActor, GPTCritic +from chatgpt.models.opt import OPTActor, OPTCritic from chatgpt.trainer import PPOTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from torch.optim import Adam diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index 44acba192..19b20b084 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -3,7 +3,10 @@ import argparse import loralib as lora import torch from chatgpt.dataset import RewardDataset -from chatgpt.nn import BLOOMRM, GPTRM, OPTRM +from chatgpt.models.base import RewardModel +from chatgpt.models.bloom import BLOOMRM +from chatgpt.models.gpt import GPTRM +from chatgpt.models.opt import OPTRM from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from datasets import load_dataset diff --git a/applications/ChatGPT/tests/test_checkpoint.py b/applications/ChatGPT/tests/test_checkpoint.py index 6cbe51569..1bbd133f7 100644 --- a/applications/ChatGPT/tests/test_checkpoint.py +++ b/applications/ChatGPT/tests/test_checkpoint.py @@ -7,7 +7,7 @@ import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp -from chatgpt.nn import GPTActor +from chatgpt.models.gpt import GPTActor from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config diff --git a/applications/ChatGPT/tests/test_data.py b/applications/ChatGPT/tests/test_data.py index b5a84c4d0..3d8fe912c 100644 --- a/applications/ChatGPT/tests/test_data.py +++ b/applications/ChatGPT/tests/test_data.py @@ -7,7 +7,8 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp from chatgpt.experience_maker import NaiveExperienceMaker -from chatgpt.nn import GPTActor, GPTCritic, RewardModel +from chatgpt.models.base import RewardModel +from chatgpt.models.gpt import GPTActor, GPTCritic from chatgpt.replay_buffer import NaiveReplayBuffer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config