mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] making experience support dp (#2971)
* [chatgpt] making experience support dp
* [chatgpt] update example test ci
* [chatgpt] update example test ci
* [chatgpt] update example test ci
* [chatgpt] update example test ci
* [chatgpt] update sampler
* [chatgpt] update example test ci
* [chatgpt] refactor sampler
* [chatgpt] update example test ci

pull/2986/head
parent 827a0af8cc
commit 19ad49fb3b
@@ -1,4 +1,3 @@
-import random
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -68,7 +67,7 @@ class Trainer(ABC):
 
     def _sample_prompts(self, prompts) -> list:
         indices = list(range(len(prompts)))
-        sampled_indices = random.sample(indices, self.experience_batch_size)
+        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
         return [prompts[i] for i in sampled_indices]
 
     def _learn(self):
 
@@ -98,6 +97,7 @@ class Trainer(ABC):
 
     def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
         time = 0
+        sampler = self.strategy.setup_sampler(prompts)
         self._on_fit_start()
         for episode in range(num_episodes):
             self._on_episode_start(episode)
 
@@ -105,7 +105,7 @@ class Trainer(ABC):
                       desc=f'Episode [{episode+1}/{num_episodes}]',
                       disable=not is_rank_0()):
                 time += 1
-                rand_prompts = self._sample_prompts(prompts)
+                rand_prompts = sampler.sample(self.experience_batch_size)
                 if self.tokenizer is not None:
                     inputs = self.tokenizer(rand_prompts)
                 else:
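The net effect of the Trainer hunks above: prompt selection is delegated to the strategy. fit() asks the strategy for a sampler once, before the episode loop, and every timestep draws experience_batch_size prompts from it, so in data-parallel runs each rank can collect experience from its own shard. A minimal, self-contained sketch of that control flow (ToySampler and ToyStrategy are illustrative stand-ins, not the classes in this patch):

# Illustrative-only sketch of the new sampling flow; ToyStrategy and
# ToySampler stand in for the real Strategy/DistributedSampler classes.
import random


class ToySampler:

    def __init__(self, dataset):
        self.dataset = list(dataset)

    def sample(self, batch_size):
        # Draw a random batch from this rank's shard, as sampler.sample() does.
        return random.sample(self.dataset, batch_size)


class ToyStrategy:

    def setup_sampler(self, dataset):
        return ToySampler(dataset)


def fit(strategy, prompts, num_timesteps=3, experience_batch_size=2):
    sampler = strategy.setup_sampler(prompts)    # created once, before the loop
    for _ in range(num_timesteps):
        rand_prompts = sampler.sample(experience_batch_size)
        print(rand_prompts)


fit(ToyStrategy(), ['p0', 'p1', 'p2', 'p3'])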
@@ -2,13 +2,16 @@ from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, List, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
-from chatgpt.nn import Actor, Critic, RewardModel
+from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
+from .sampler import DistributedSampler
+
 ModelOptimPair = Tuple[nn.Module, Optimizer]
 ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
 
@@ -123,3 +126,6 @@ class Strategy(ABC):
     @abstractmethod
     def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
         pass
+
+    def setup_sampler(self, dataset) -> DistributedSampler:
+        return DistributedSampler(dataset, 1, 0)
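The base Strategy gains a setup_sampler hook whose default is a single-replica sampler (num_replicas=1, rank=0), i.e. the full prompt set; distributed strategies override it, as the DDP hunk below does. A toy illustration of that override pattern (the classes and the hard-coded world_size/rank are stand-ins, not the real implementation):

# Toy illustration of the setup_sampler override pattern; not the real classes.
class Sampler:

    def __init__(self, dataset, num_replicas, rank):
        # Each replica keeps every num_replicas-th element starting at its rank.
        self.shard = list(dataset)[rank::num_replicas]


class BaseStrategy:

    def setup_sampler(self, dataset):
        return Sampler(dataset, 1, 0)    # one replica: the whole dataset


class DDPLikeStrategy(BaseStrategy):
    world_size, rank = 2, 1              # would come from torch.distributed

    def setup_sampler(self, dataset):
        return Sampler(dataset, self.world_size, self.rank)


print(BaseStrategy().setup_sampler(range(6)).shard)     # [0, 1, 2, 3, 4, 5]
print(DDPLikeStrategy().setup_sampler(range(6)).shard)  # [1, 3, 5]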
@@ -9,10 +9,11 @@ from chatgpt.nn import Actor
 from chatgpt.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
 
 from .base import Strategy
 from .naive import NaiveStrategy
+from .sampler import DistributedSampler
 
 
 class DDPStrategy(NaiveStrategy):
 
@@ -49,15 +50,19 @@ class DDPStrategy(NaiveStrategy):
         return DDP(model, device_ids=[device])
 
     def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
-        sampler = DistributedSampler(replay_buffer,
-                                     num_replicas=dist.get_world_size(),
-                                     rank=dist.get_rank(),
-                                     shuffle=True,
-                                     seed=self.seed,
-                                     drop_last=True)
-        return DataLoader(replay_buffer,
+        # DDP only mode, replay buffers on each rank are different.
+        # sampler = DistributedSampler(replay_buffer,
+        #                              num_replicas=dist.get_world_size(),
+        #                              rank=dist.get_rank(),
+        #                              shuffle=True,
+        #                              seed=self.seed,
+        #                              drop_last=True)
+        return DataLoader(
+            replay_buffer,
             batch_size=replay_buffer.sample_batch_size,
-            sampler=sampler,
+            # sampler=sampler,
+            shuffle=True,
+            drop_last=True,
             pin_memory=pin_memory,
             collate_fn=replay_buffer.collate_fn)
 
@@ -75,3 +80,6 @@ class DDPStrategy(NaiveStrategy):
         if only_rank0 and dist.get_rank() != 0:
             return
         super().save_optimizer(optimizer, path, only_rank0)
+
+    def setup_sampler(self, dataset) -> DistributedSampler:
+        return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
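With prompts already sharded per rank, each rank's replay buffer holds different experiences, which is why setup_dataloader drops torch's DistributedSampler in favor of a plain shuffled DataLoader with drop_last=True. A rough, self-contained sketch of that pattern (ToyBuffer is a stand-in for ReplayBuffer; the real one also supplies sample_batch_size and a custom collate_fn):

# Sketch only: a per-rank buffer served by an ordinary shuffled DataLoader.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBuffer(Dataset):

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


buffer = ToyBuffer([torch.tensor(float(i)) for i in range(10)])
loader = DataLoader(buffer, batch_size=4, shuffle=True, drop_last=True)
for batch in loader:
    print(batch)    # two batches of 4; the last 2 items are dropped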
@@ -0,0 +1,32 @@
+import math
+
+import numpy as np
+
+
+class DistributedSampler:
+
+    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+
+        if len(self.dataset) % self.num_replicas != 0:
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+
+        self.total_size = self.num_samples * self.num_replicas
+
+        indices = list(range(len(self.dataset)))
+        indices = indices[:self.total_size]
+        assert len(indices) == self.total_size
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+        self.indices = indices
+
+    def sample(self, batch_size: int) -> list:
+        sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
+        return [self.dataset[idx] for idx in sampled_indices]
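For reference, the sharding scheme implemented in the new file gives rank r the indices r, r + num_replicas, r + 2*num_replicas, ... after truncating the dataset to total_size. A standalone reproduction of that arithmetic (it re-implements the logic rather than importing the new module):

# Standalone check of the sharding arithmetic used by DistributedSampler above.
import math


def shard(dataset_len: int, num_replicas: int, rank: int) -> list:
    if dataset_len % num_replicas != 0:
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))[:total_size]
    # interleaved subsampling: rank, rank + num_replicas, rank + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


print(shard(7, 2, 0))    # [0, 2, 4] -- rank 0's shard of a 7-prompt dataset
print(shard(7, 2, 1))    # [1, 3, 5] -- rank 1's shard; index 6 is dropped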
@@ -15,13 +15,11 @@ export OMP_NUM_THREADS=8
 pip install -r ${BASE}/requirements.txt
 
 # train dummy
-python ${BASE}/train_dummy.py --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
 for strategy in ddp colossalai_gemini colossalai_zero2; do
-    torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
+    torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --experience_batch_size 4 --train_batch_size 4
 done
 
 # train prompts
-python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
 for strategy in ddp colossalai_gemini colossalai_zero2; do
-    torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 3 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --train_batch_size 2
+    torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
 done
@@ -25,7 +25,7 @@ def main(args):
     elif args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
     elif args.strategy == 'colossalai_zero2':
         strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
     else:
 
@@ -82,6 +82,7 @@ def main(args):
                          critic_optim,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
+                         experience_batch_size=args.experience_batch_size,
                          tokenizer=preprocess_batch,
                          max_length=128,
                          do_sample=True,
 
@@ -117,6 +118,7 @@ if __name__ == '__main__':
     parser.add_argument('--update_timesteps', type=int, default=10)
     parser.add_argument('--max_epochs', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
     main(args)
@@ -20,7 +20,7 @@ def main(args):
     elif args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
     elif args.strategy == 'colossalai_zero2':
         strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
     else:
 
@@ -83,6 +83,7 @@ def main(args):
                          critic_optim,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
+                         experience_batch_size=args.experience_batch_size,
                          tokenizer=tokenize_fn,
                          max_length=128,
                          do_sample=True,
 
@@ -117,6 +118,7 @@ if __name__ == '__main__':
     parser.add_argument('--update_timesteps', type=int, default=10)
     parser.add_argument('--max_epochs', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
     main(args)
@@ -107,6 +107,7 @@ def run_dist(rank, world_size, port, strategy):
     run_test_data(strategy)
 
 
+@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])