[chatgpt] making experience support dp (#2971)

* [chatgpt] making experience support dp

* [chatgpt] update example test ci

* [chatgpt] update example test ci

* [chatgpt] update example test ci

* [chatgpt] update example test ci

* [chatgpt] update sampler

* [chatgpt] update example test ci

* [chatgpt] refactor sampler

* [chatgpt] update example test ci
pull/2986/head
ver217 2023-03-03 15:51:19 +08:00 committed by GitHub
parent 827a0af8cc
commit 19ad49fb3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 71 additions and 22 deletions

View File

@ -1,4 +1,3 @@
import random
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
@ -68,7 +67,7 @@ class Trainer(ABC):
def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts)))
sampled_indices = random.sample(indices, self.experience_batch_size)
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
return [prompts[i] for i in sampled_indices]
def _learn(self):
@ -98,6 +97,7 @@ class Trainer(ABC):
def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
time = 0
sampler = self.strategy.setup_sampler(prompts)
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
@ -105,7 +105,7 @@ class Trainer(ABC):
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
rand_prompts = self._sample_prompts(prompts)
rand_prompts = sampler.sample(self.experience_batch_size)
if self.tokenizer is not None:
inputs = self.tokenizer(rand_prompts)
else:

View File

@ -2,13 +2,16 @@ from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from chatgpt.nn import Actor, Critic, RewardModel
from chatgpt.nn import Actor
from chatgpt.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .sampler import DistributedSampler
ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
@ -123,3 +126,6 @@ class Strategy(ABC):
@abstractmethod
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
pass
def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, 1, 0)

View File

@ -9,10 +9,11 @@ from chatgpt.nn import Actor
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader
from .base import Strategy
from .naive import NaiveStrategy
from .sampler import DistributedSampler
class DDPStrategy(NaiveStrategy):
@ -49,17 +50,21 @@ class DDPStrategy(NaiveStrategy):
return DDP(model, device_ids=[device])
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
sampler = DistributedSampler(replay_buffer,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
shuffle=True,
seed=self.seed,
drop_last=True)
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
sampler=sampler,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
# DDP only mode, replay buffers on each rank are different.
# sampler = DistributedSampler(replay_buffer,
# num_replicas=dist.get_world_size(),
# rank=dist.get_rank(),
# shuffle=True,
# seed=self.seed,
# drop_last=True)
return DataLoader(
replay_buffer,
batch_size=replay_buffer.sample_batch_size,
# sampler=sampler,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
@ -75,3 +80,6 @@ class DDPStrategy(NaiveStrategy):
if only_rank0 and dist.get_rank() != 0:
return
super().save_optimizer(optimizer, path, only_rank0)
def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

View File

@ -0,0 +1,32 @@
import math
import numpy as np
class DistributedSampler:
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
def sample(self, batch_size: int) -> list:
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
return [self.dataset[idx] for idx in sampled_indices]

View File

@ -15,13 +15,11 @@ export OMP_NUM_THREADS=8
pip install -r ${BASE}/requirements.txt
# train dummy
python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
for strategy in ddp colossalai_gemini colossalai_zero2; do
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --experience_batch_size 4 --train_batch_size 4
done
# train prompts
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
for strategy in ddp colossalai_gemini colossalai_zero2; do
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
done

View File

@ -25,7 +25,7 @@ def main(args):
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
@ -82,6 +82,7 @@ def main(args):
critic_optim,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=preprocess_batch,
max_length=128,
do_sample=True,
@ -117,6 +118,7 @@ if __name__ == '__main__':
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args()
main(args)

View File

@ -20,7 +20,7 @@ def main(args):
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
@ -83,6 +83,7 @@ def main(args):
critic_optim,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=tokenize_fn,
max_length=128,
do_sample=True,
@ -117,6 +118,7 @@ if __name__ == '__main__':
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args()
main(args)

View File

@ -107,6 +107,7 @@ def run_dist(rank, world_size, port, strategy):
run_test_data(strategy)
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])