add microbatch forwarding

grpo-latest-dev
YeAnbang 2025-03-21 18:30:44 +08:00
parent 2aa7385c88
commit 0a3b8b050d
3 changed files with 29 additions and 6 deletions

View File

@@ -11,7 +11,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, get_logits_rebatched_forward
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -131,14 +132,17 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0))

         need_update = (step_idx + 1) % self.num_microbatches == 0
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
-            policy_model_logits = self.policy_model(
+            policy_model_logits = get_logits_rebatched_forward(
+                self.policy_model,
+                forward_batch_size,
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
-            )["logits"]
+            )
             action_log_probs = calc_action_log_probs(
                 policy_model_logits / self.generate_config["temperature"],
                 data["input_ids"],
@@ -147,10 +151,12 @@ class GRPOConsumer(BaseConsumer):
             )

             with torch.no_grad():
-                reference_model_logits = self.reference_model(
+                reference_model_logits = get_logits_rebatched_forward(
+                    self.reference_model,
+                    forward_batch_size,
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
-                )["logits"]
+                )
                 reference_action_log_probs = calc_action_log_probs(
                     reference_model_logits / self.generate_config["temperature"],
                     data["input_ids"],

View File

@@ -101,7 +101,7 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": 4},
             num_generations=num_generations,
         )
         procs.append(consumer)
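
The value 4 above is hardcoded; in the consumer the key is read with data["input_ids"].size(0) as the default, so omitting it keeps the previous single-pass behaviour. A minimal sketch of that fallback (the dict and batch size below are illustrative, not the real launch values):

# Illustrative values; the real dict is built in launch_distributed above.
training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6}  # "forward_micro_batch_size" omitted
full_batch_size = 8                                          # stands in for data["input_ids"].size(0)
forward_batch_size = training_config.get("forward_micro_batch_size", full_batch_size)
assert forward_batch_size == 8  # falls back to one full-batch forward pass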

View File

@@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask):
+    """
+    Get logits from the model with rebatched forward.
+    Args:
+        model (torch.nn.Module): The model.
+        batch_size (int): The batch size.
+        input_ids (torch.Tensor): The input ids.
+        attention_mask (torch.Tensor): The attention mask.
+    """
+    logits = []
+    for i in range(0, input_ids.size(0), batch_size):
+        logits.append(
+            model(input_ids=input_ids[i : i + batch_size], attention_mask=attention_mask[i : i + batch_size])["logits"]
+        )
+    return torch.cat(logits, dim=0)
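
A hedged usage sketch of the new helper (the model name and tokenizer setup are placeholders, not from this commit). The memory benefit is largest on the no_grad reference pass, where each micro-batch's activations can be freed before the next chunk; on the policy pass the autograd graphs of all chunks stay alive until backward, so the peak saving there is smaller:

import torch
from coati.distributed.utils import get_logits_rebatched_forward
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "gpt2"  # placeholder; any causal LM whose forward returns a "logits" entry works
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_path)

batch = tokenizer(["an example prompt"] * 8, return_tensors="pt", padding=True)

with torch.no_grad():
    full = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
    chunked = get_logits_rebatched_forward(
        model, 4, input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )

# Chunking only changes peak memory, not the result (up to kernel-level numerical noise).
assert full.shape == chunked.shape
assert torch.allclose(full, chunked, atol=1e-4)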