mirror of https://github.com/hpcaitech/ColossalAI
add microbatch forwarding
parent 2aa7385c88
commit 0a3b8b050d
@@ -11,7 +11,7 @@ from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, get_logits_rebatched_forward
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -131,14 +132,17 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0))
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
-            policy_model_logits = self.policy_model(
+            policy_model_logits = get_logits_rebatched_forward(
+                self.policy_model,
+                forward_batch_size,
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
-            )["logits"]
+            )
             action_log_probs = calc_action_log_probs(
                 policy_model_logits / self.generate_config["temperature"],
                 data["input_ids"],
@@ -147,10 +151,12 @@ class GRPOConsumer(BaseConsumer):
             )
 
             with torch.no_grad():
-                reference_model_logits = self.reference_model(
+                reference_model_logits = get_logits_rebatched_forward(
+                    self.reference_model,
+                    forward_batch_size,
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
-                )["logits"]
+                )
                 reference_action_log_probs = calc_action_log_probs(
                     reference_model_logits / self.generate_config["temperature"],
                     data["input_ids"],
@@ -101,7 +101,7 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": 4},
             num_generations=num_generations,
         )
         procs.append(consumer)
@@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask):
+    """
+    Get logits from the model with rebatched forward.
+    Args:
+        model (torch.nn.Module): The model.
+        batch_size (int): The batch size.
+        input_ids (torch.Tensor): The input ids.
+        attention_mask (torch.Tensor): The attention mask.
+    """
+    logits = []
+    for i in range(0, input_ids.size(0), batch_size):
+        logits.append(
+            model(input_ids=input_ids[i : i + batch_size], attention_mask=attention_mask[i : i + batch_size])["logits"]
+        )
+    return torch.cat(logits, dim=0)
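
For context, below is a minimal self-contained sketch of the micro-batch forwarding pattern this commit introduces: the full rollout batch is forwarded in slices of forward_micro_batch_size, so each forward call processes at most that many sequences, which can lower the peak memory of the logits computation (most clearly for the no-grad reference pass). The gpt2 checkpoint, toy prompts, and micro-batch size of 4 are illustrative assumptions only; the commit itself applies the helper to the policy and reference models inside GRPOConsumer.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask):
    # Same logic as the helper added to coati.distributed.utils in this commit:
    # forward the batch in slices of `batch_size` and concatenate the logits.
    logits = []
    for i in range(0, input_ids.size(0), batch_size):
        logits.append(
            model(
                input_ids=input_ids[i : i + batch_size],
                attention_mask=attention_mask[i : i + batch_size],
            )["logits"]
        )
    return torch.cat(logits, dim=0)


# Stand-in model and prompts purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["1 + 1 =", "2 + 2 =", "3 + 3 =", "4 + 4 ="] * 2  # 8 sequences
batch = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    # Forward 8 sequences in micro-batches of 4, mirroring the default
    # "forward_micro_batch_size": 4 set in training_config above.
    logits = get_logits_rebatched_forward(
        model, 4, batch["input_ids"], batch["attention_mask"]
    )

print(logits.shape)  # torch.Size([8, seq_len, vocab_size]): full-batch logits from micro-batched forwards

Note that torch.cat along dim 0 requires every slice to produce logits of the same sequence length; this holds because all slices are cut from a single already-padded batch, as they are in the consumer's training step.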