mirror of https://github.com/hpcaitech/ColossalAI
add microbatch forwarding
parent 2aa7385c88
commit 0a3b8b050d
@@ -11,7 +11,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, get_logits_rebatched_forward
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

@@ -131,14 +132,17 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0))

         need_update = (step_idx + 1) % self.num_microbatches == 0
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
-            policy_model_logits = self.policy_model(
+            policy_model_logits = get_logits_rebatched_forward(
+                self.policy_model,
+                forward_batch_size,
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
-            )["logits"]
+            )
             action_log_probs = calc_action_log_probs(
                 policy_model_logits / self.generate_config["temperature"],
                 data["input_ids"],

@@ -147,10 +151,12 @@ class GRPOConsumer(BaseConsumer):
             )

             with torch.no_grad():
-                reference_model_logits = self.reference_model(
+                reference_model_logits = get_logits_rebatched_forward(
+                    self.reference_model,
+                    forward_batch_size,
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
-                )["logits"]
+                )
             reference_action_log_probs = calc_action_log_probs(
                 reference_model_logits / self.generate_config["temperature"],
                 data["input_ids"],

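Note (not part of the commit): the change above swaps the single full-batch logits forward for a chunked one. A minimal standalone sketch follows, assuming a toy stand-in model (TinyLM, micro_batch_size, and tensor shapes are illustrative, not from the repository), showing that running the forward in slices and concatenating the logits reproduces the full-batch logits and gradients, so the GRPO loss math is unchanged:

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy stand-in for the policy model: token ids -> vocabulary logits."""

    def __init__(self, vocab_size=32, hidden_size=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        # Returns a dict with a "logits" entry, mirroring HF causal LM outputs.
        return {"logits": self.head(self.emb(input_ids))}

torch.manual_seed(0)
model = TinyLM()
input_ids = torch.randint(0, 32, (8, 10))          # full rollout batch
attention_mask = torch.ones_like(input_ids)
micro_batch_size = 4                               # cf. forward_micro_batch_size

# Single full-batch forward.
full_logits = model(input_ids=input_ids, attention_mask=attention_mask)["logits"]

# Micro-batched forward: slices of `micro_batch_size`, concatenated afterwards.
sliced = [
    model(
        input_ids=input_ids[i : i + micro_batch_size],
        attention_mask=attention_mask[i : i + micro_batch_size],
    )["logits"]
    for i in range(0, input_ids.size(0), micro_batch_size)
]
rebatched_logits = torch.cat(sliced, dim=0)

assert torch.allclose(full_logits, rebatched_logits, atol=1e-6)

# Gradients of a toy loss agree as well, since each slice contributes its share.
g_full = torch.autograd.grad(full_logits.pow(2).mean(), model.head.weight)[0]
g_rebatched = torch.autograd.grad(rebatched_logits.pow(2).mean(), model.head.weight)[0]
assert torch.allclose(g_full, g_rebatched, atol=1e-6)
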
@@ -101,7 +101,7 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": 4},
             num_generations=num_generations,
         )
         procs.append(consumer)

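Note (not part of the commit): with "forward_micro_batch_size": 4 set here, the consumer recomputes policy and reference logits in slices of 4 sequences per forward call. If the key is omitted, forward_batch_size falls back to data["input_ids"].size(0) (the full batch), so the chunked helper runs a single slice and reproduces the previous full-batch behavior.
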
@@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask):
+    """
+    Get logits from the model with rebatched forward.
+    Args:
+        model (torch.nn.Module): The model.
+        batch_size (int): The batch size.
+        input_ids (torch.Tensor): The input ids.
+        attention_mask (torch.Tensor): The attention mask.
+    """
+    logits = []
+    for i in range(0, input_ids.size(0), batch_size):
+        logits.append(
+            model(input_ids=input_ids[i : i + batch_size], attention_mask=attention_mask[i : i + batch_size])["logits"]
+        )
+    return torch.cat(logits, dim=0)
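
Note (not part of the commit): a short usage sketch of the new helper, reusing the illustrative TinyLM model and tensors from the earlier note (hypothetical, not from the repository). The helper does not disable gradients itself; the caller decides, which is why the consumer wraps only the reference-model call in torch.no_grad().

# Gradient-tracking call, as done for the policy model inside `with ctx:`.
policy_logits = get_logits_rebatched_forward(
    model,                      # any module whose forward returns {"logits": ...}
    4,                          # slice size, cf. forward_micro_batch_size
    input_ids=input_ids,
    attention_mask=attention_mask,
)
assert policy_logits.shape == (8, 10, 32) and policy_logits.requires_grad

# Gradient-free call, as done for the reference model.
with torch.no_grad():
    reference_logits = get_logits_rebatched_forward(model, 4, input_ids=input_ids, attention_mask=attention_mask)
assert not reference_logits.requires_grad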