[gemini] accelerate inference (#3641)

* [gemini] support don't scatter after inference

* [chat] update colossalai strategy

* [chat] fix opt benchmark

* [chat] update opt benchmark

* [gemini] optimize inference

* [test] add gemini inference test

* [chat] fix unit test ci

* [chat] fix ci

* [chat] fix ci

* [chat] skip checkpoint test
pull/3647/head
Hongxin Liu 2023-04-26 16:32:40 +08:00 committed by GitHub
parent 4b3240cb59
commit 50793b35f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 162 additions and 157 deletions

View File

@ -32,14 +32,14 @@ jobs:
- name: Install ColossalAI and ChatGPT - name: Install ColossalAI and ChatGPT
run: | run: |
pip install -v . pip install -e .
cd applications/ChatGPT cd applications/Chat
pip install -v . pip install -v .
pip install -r requirements-test.txt pip install -r requirements-test.txt
- name: Execute Unit Testing - name: Execute Unit Testing
run: | run: |
cd applications/ChatGPT cd applications/Chat
rm -rf ~/.cache/colossalai rm -rf ~/.cache/colossalai
pytest tests/ pytest tests/
env: env:

View File

@ -10,6 +10,7 @@ from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator from coati.trainer.callbacks import PerformanceEvaluator
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.configuration_opt import OPTConfig
@ -92,13 +93,13 @@ def main(args):
torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)
model_config = get_gpt_config(args.model) model_config = get_gpt_config(args.model)
critic_config = get_gpt_config(args.critic_model)
with strategy.model_init_context(): with strategy.model_init_context():
actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda() actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda() critic = OPTCritic(config=critic_config, lora_rank=args.lora_rank).cuda()
initial_model = deepcopy(actor).cuda() initial_model = deepcopy(actor).cuda().half()
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half()
actor_numel = get_model_numel(actor, strategy) actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy) critic_numel = get_model_numel(critic, strategy)
@ -127,8 +128,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
trainer = PPOTrainer(strategy, trainer = PPOTrainer(strategy,
actor, actor,
@ -137,6 +137,7 @@ def main(args):
initial_model, initial_model,
actor_optim, actor_optim,
critic_optim, critic_optim,
ptx_coef=0,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size, experience_batch_size=args.experience_batch_size,
@ -145,14 +146,19 @@ def main(args):
do_sample=True, do_sample=True,
temperature=1.0, temperature=1.0,
top_k=50, top_k=50,
use_cache=True,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator]) callbacks=[performance_evaluator])
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool) dataloader = DataLoader(random_prompts,
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)] batch_size=args.experience_batch_size,
trainer.fit(random_prompts, random_pretrain, shuffle=True,
collate_fn=preprocess_batch)
trainer.fit(dataloader,
None,
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps, max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps) update_timesteps=args.update_timesteps)
@ -163,6 +169,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='125m') parser.add_argument('--model', default='125m')
parser.add_argument('--critic_model', default='125m')
parser.add_argument('--strategy', parser.add_argument('--strategy',
choices=[ choices=[
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
@ -175,7 +182,7 @@ if __name__ == '__main__':
parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument('--max_epochs', type=int, default=3)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=4) parser.add_argument('--lora_rank', type=int, default=0)
parser.add_argument('--cuda_mem_frac', type=float, default=1.0) parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,5 +1,6 @@
import copy import copy
import random import random
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence from typing import Callable, Dict, Sequence
@ -19,9 +20,13 @@ logger = get_dist_logger()
class PromptDataset(Dataset): class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None): def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 96):
super(PromptDataset, self).__init__() super(PromptDataset, self).__init__()
self.prompt = [] self.keyed_prompt = defaultdict(list)
logger.info("Loading data...") logger.info("Loading data...")
list_data_dict = jload(data_path) list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.") logger.info(f"Loaded {len(list_data_dict)} examples.")
@ -33,14 +38,14 @@ class PromptDataset(Dataset):
for data_dict in list_data_dict: for data_dict in list_data_dict:
token = tokenizer(data_dict["instruction"], token = tokenizer(data_dict["instruction"],
return_tensors='pt', return_tensors='pt',
max_length=96, max_length=max_length,
padding='max_length', padding='max_length',
truncation=True) truncation=True)
for idx in token['input_ids']: for k, tensor in token.items():
self.prompt.append(idx.to(torch.cuda.current_device())) self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
def __len__(self): def __len__(self):
return len(self.prompt) return len(self.keyed_prompt)
def __getitem__(self, i) -> Dict[str, torch.Tensor]: def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return self.prompt[i] return {k: v[i] for k, v in self.keyed_prompt.items()}

View File

@ -76,7 +76,7 @@ def sample(model: nn.Module,
# update generated ids, model inputs for next step # update generated ids, model inputs for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None: if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs) model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
# if eos_token was found in one sentence, set sentence to finished # if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None: if eos_token_id is not None:

View File

@ -1,92 +0,0 @@
from typing import Optional
import torch
def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs["past_key_values"]
else:
model_kwargs["past"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
return model_kwargs
def opt_prepare_inputs_fn(input_ids: torch.Tensor,
past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs) -> dict:
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs) -> dict:
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}

View File

@ -4,14 +4,13 @@ import torch
import torch.nn as nn import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic from coati.models.base import Actor, Critic
from coati.models.generation_utils import update_model_kwargs_fn
from coati.models.loss import PolicyLoss, ValueLoss from coati.models.loss import PolicyLoss, ValueLoss
from coati.replay_buffer import NaiveReplayBuffer from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DistributedSampler from torch.utils.data import DistributedSampler
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from tqdm import tqdm from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Trainer from .base import Trainer
from .callbacks import Callback from .callbacks import Callback
@ -102,19 +101,16 @@ class PPOTrainer(Trainer):
def _sample_prompts(self, prompts) -> list: def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts))) indices = list(range(len(prompts)))
sampled_indices = self.strategy.experience_sampler.choice( sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
indices, self.experience_batch_size, replace=False)
return [prompts[i] for i in sampled_indices] return [prompts[i] for i in sampled_indices]
def _learn(self): def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training # replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer: if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader( dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
self.replay_buffer, self.dataloader_pin_memory)
device = torch.cuda.current_device() device = torch.cuda.current_device()
if self.sample_replay_buffer: if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
disable=not is_rank_0())
for _ in pbar: for _ in pbar:
experience = self.replay_buffer.sample() experience = self.replay_buffer.sample()
metrics = self.training_step(experience) metrics = self.training_step(experience)
@ -124,8 +120,7 @@ class PPOTrainer(Trainer):
self._on_learn_epoch_start(epoch) self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler): if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch) dataloader.sampler.set_epoch(epoch)
pbar = tqdm( pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar: for experience in pbar:
self._on_learn_batch_start() self._on_learn_batch_start()
experience.to_device(device) experience.to_device(device)
@ -152,10 +147,8 @@ class PPOTrainer(Trainer):
time += 1 time += 1
prompts = next(iter(self.prompt_dataloader)) prompts = next(iter(self.prompt_dataloader))
self._on_make_experience_start() self._on_make_experience_start()
self.experience_maker.initial_model.to( self.experience_maker.initial_model.to(torch.cuda.current_device())
torch.cuda.current_device()) self.experience_maker.reward_model.to(torch.cuda.current_device())
self.experience_maker.reward_model.to(
torch.cuda.current_device())
experience = self._make_experience(prompts) experience = self._make_experience(prompts)
self._on_make_experience_end(experience) self._on_make_experience_end(experience)
self.replay_buffer.append(experience) self.replay_buffer.append(experience)
@ -207,7 +200,10 @@ class PPOTrainer(Trainer):
return {'reward': experience.reward.mean().item()} return {'reward': experience.reward.mean().item()}
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: def save_model(self,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
@ -218,7 +214,7 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs: if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
return new_kwargs return new_kwargs

View File

@ -67,6 +67,7 @@ class ColossalAIStrategy(DDPStrategy):
placement_policy: str = 'cuda', placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3 pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3
scatter_after_inference: bool = False, # only for stage 3
search_range_mb: int = 32, # only for stage 3 search_range_mb: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3 hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_mb: float = 32, # only for stage 3 min_chunk_size_mb: float = 32, # only for stage 3
@ -103,7 +104,8 @@ class ColossalAIStrategy(DDPStrategy):
strict_ddp_mode=shard_init, strict_ddp_mode=shard_init,
search_range_mb=search_range_mb, search_range_mb=search_range_mb,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb) min_chunk_size_mb=min_chunk_size_mb,
scatter_after_inference=scatter_after_inference)
if stage == 3: if stage == 3:
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio) self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
else: else:
@ -159,14 +161,6 @@ class ColossalAIStrategy(DDPStrategy):
return model.module return model.module
return model return model
def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module:
if isinstance(model, ZeroDDP) and self.stage == 3:
logger.info(f"model type: {type(model)}, get static torch model")
model = get_static_torch_model(model)
logger.info(f"unwrapped_model type: {type(model)}")
return super()._unwrap_model(model)
def save_model(self, def save_model(self,
model: nn.Module, model: nn.Module,
path: str, path: str,

View File

@ -82,6 +82,7 @@ def run_dist(rank, world_size, port, strategy):
run_test_checkpoint(strategy) run_test_checkpoint(strategy)
@pytest.mark.skip('temporarily skip until refactor strategy unwrap')
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) @pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])

View File

@ -1,5 +1,6 @@
import itertools import itertools
from collections import OrderedDict from collections import OrderedDict
from contextlib import nullcontext
from functools import partial from functools import partial
from typing import Dict, Iterator, List, Optional, Union from typing import Dict, Iterator, List, Optional, Union
@ -49,6 +50,7 @@ class ZeroDDP(ColoDDP):
Defaults to False. Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP. Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
""" """
def __init__(self, def __init__(self,
@ -56,7 +58,8 @@ class ZeroDDP(ColoDDP):
gemini_manager: GeminiManager, gemini_manager: GeminiManager,
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None: strict_ddp_mode: bool = False,
scatter_after_inference: bool = True) -> None:
self.gemini_manager = gemini_manager self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32 self.force_outputs_fp32 = force_outputs_fp32
@ -67,6 +70,7 @@ class ZeroDDP(ColoDDP):
self.grads_device: Dict[torch.Tensor, torch.device] = dict() self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict() self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict() self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self._logger = get_dist_logger() self._logger = get_dist_logger()
@ -108,8 +112,6 @@ class ZeroDDP(ColoDDP):
first_param = next(iter(chunk.tensors_info)) first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0 assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
# check whether we are in a inference mode # check whether we are in a inference mode
@ -120,17 +122,35 @@ class ZeroDDP(ColoDDP):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True) self.module.zero_grad(set_to_none=True)
if not grad_flag:
outputs = self._inference_forward(*args, **kwargs)
else:
self.gemini_manager.pre_iter(*args) self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook): with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs) outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32: if self.force_outputs_fp32:
return _cast_float(outputs, torch.float) return _cast_float(outputs, torch.float)
return outputs return outputs
def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference.
"""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
if not self.scatter_after_inference:
# gather all chunks
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
self.chunk_manager.access_chunk(chunk)
fwd_ctx = nullcontext()
with fwd_ctx:
outputs = self.module(*args, **kwargs)
if self.scatter_after_inference:
# scatter chunks
self._post_forward()
# reset all recorded attributes
self.gemini_manager.reset_attributes()
return outputs
def _setup_grads_ptr(self): def _setup_grads_ptr(self):
for p in self.module.parameters(): for p in self.module.parameters():
if is_ddp_ignored(p): if is_ddp_ignored(p):
@ -678,6 +698,7 @@ class GeminiDDP(ZeroDDP):
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False, strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
search_range_mb: int = 32, search_range_mb: int = 32,
hidden_dim: Optional[int] = None, hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32, min_chunk_size_mb: float = 32,
@ -722,4 +743,5 @@ class GeminiDDP(ZeroDDP):
strict_ddp_flag=strict_ddp_mode, strict_ddp_flag=strict_ddp_mode,
verbose=verbose) verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
scatter_after_inference)

View File

@ -9,11 +9,11 @@ from . import (
resnet, resnet,
simple_net, simple_net,
) )
from .utils import run_fwd_bwd from .utils import run_fwd, run_fwd_bwd
from . import albert # isort:skip from . import albert # isort:skip
__all__ = [ __all__ = [
'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet',
'simple_net', 'run_fwd_bwd', 'albert', 'beit' 'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd'
] ]

View File

@ -1,2 +1,2 @@
from .dummy_data_generator import DummyDataGenerator from .dummy_data_generator import DummyDataGenerator
from .executor import run_fwd_bwd from .executor import run_fwd, run_fwd_bwd

View File

@ -1,9 +1,9 @@
import torch import torch
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: def run_fwd(model, data, label, criterion) -> torch.Tensor:
"""run_fwd_bwd """run_fwd
run fwd and bwd for the model run fwd for the model
Args: Args:
model (torch.nn.Module): a PyTorch model model (torch.nn.Module): a PyTorch model
@ -22,6 +22,23 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
loss = model(data, label) loss = model(data, label)
loss = loss.float() loss = loss.float()
return loss
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
Args:
model (torch.nn.Module): a PyTorch model
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
Returns:
torch.Tensor: loss of fwd
"""
loss = run_fwd(model, data, label, criterion)
if optimizer: if optimizer:
optimizer.backward(loss) optimizer.backward(loss)
else: else:

View File

@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd_bwd from tests.components_to_test import run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed from tests.test_tensor.common_utils import set_seed
@ -89,10 +89,65 @@ def exam_gpt_fwd_bwd(
check_grad(model, torch_model) check_grad(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
placement_policy,
keep_gather,
model_name: str,
scatter_after_inference: bool = False,
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
set_seed(42)
with ColoInitContext(device=init_device):
model = model_builder()
set_seed(42)
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
set_seed(pg.dp_local_rank())
model.eval()
torch_model.eval()
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0:
break
with torch.no_grad():
input_ids, label = input_ids.cuda(), label.cuda()
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
loss = run_fwd(model, input_ids, label, criterion)
assert torch.equal(torch_loss, loss)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd() exam_gpt_fwd_bwd()
exam_gpt_inference()
@pytest.mark.dist @pytest.mark.dist