
Fine-Tuning CoT: GRPO vs. PPO vs. Few-Shot Prompting

Implementation of GRPO, PPO, and CoT few-shot prompting from scratch to compare fine-tuning approaches for Chain-of-Thought reasoning in GPT-Neo using the GSM8K dataset.

(Figure: GRPO objective function)

šŸ” Comparing Fine-Tuning Approaches for Chain-of-Thought (CoT)

This project explores the effectiveness of three different fine-tuning strategies for Chain-of-Thought (CoT) reasoning in large language models:

  • Group Relative Policy Optimization (GRPO)
  • Proximal Policy Optimization (PPO)
  • Few-Shot Prompting

We apply these techniques to GPT-Neo using the GSM8K dataset, a benchmark for arithmetic and logical reasoning. A reward model built on DistilRoBERTa-base supplies the scalar reward signal for the reinforcement-learning methods.


šŸš€ Implementing GRPO from Scratch

GRPO (Group Relative Policy Optimization) dispenses with a learned value function: advantages are computed relative to a group of responses sampled for the same prompt, and a KL penalty keeps the updated policy close to a reference policy. Below is the clipped surrogate objective used in this implementation:

import torch
import torch.nn as nn
import torch.optim as optim

def grpo_loss(policy, old_policy, advantage, epsilon=0.2, beta=0.01):
    # policy / old_policy: action probabilities under the current and
    # reference policies; advantage: group-relative advantage estimate.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_loss = torch.min(ratio * advantage, clipped_ratio * advantage)
    # F.kl_div takes log-probs as input and probs as target, so this
    # computes KL(policy || old_policy).
    kl_div = torch.nn.functional.kl_div(old_policy.log(), policy, reduction="batchmean")
    # Maximize the clipped surrogate while penalizing drift from the
    # reference policy, so the KL term is added to the loss.
    return -policy_loss.mean() + beta * kl_div
 
class GRPOAgent(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, output_dim), nn.Softmax(dim=-1)
        )
    
    def forward(self, x):
        return self.policy(x)
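
The step that gives GRPO its name is the group-relative advantage: for each prompt, several responses are sampled and each reward is normalized against the group's mean and standard deviation, which removes the need for a separate value network. A minimal sketch of that step, assuming one group of sampled responses per prompt (the function name is ours, not part of the code above):

def group_relative_advantage(rewards, eps=1e-8):
    # rewards: tensor of shape (group_size,) with the reward of each
    # response sampled for the same prompt.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Example: four sampled answers to one GSM8K question, two of them correct.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
advantages = group_relative_advantage(rewards)  # positive for the correct answers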

šŸ”„ Implementing PPO from Scratch

PPO constrains policy updates using a clipped objective to prevent large policy shifts.

def ppo_loss(policy, old_policy, advantage, epsilon=0.2):
    # Clipped surrogate objective: the probability ratio is clipped to
    # [1 - epsilon, 1 + epsilon] to keep updates close to the old policy.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()
 
class PPOAgent(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, output_dim), nn.Softmax(dim=-1)
        )
 
    def forward(self, x):
        return self.policy(x)
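
For orientation, a single update step with the pieces above might look like the following; the dimensions, learning rate, and advantage values are placeholders rather than settings from this project:

agent = PPOAgent(input_dim=16, output_dim=4)
optimizer = optim.Adam(agent.parameters(), lr=3e-4)

states = torch.randn(8, 16)           # placeholder batch of states
actions = torch.randint(0, 4, (8,))   # actions sampled under the old policy
advantage = torch.randn(8)            # placeholder advantage estimates

# In a full training loop, old_probs would be cached from a policy snapshot
# taken before several epochs of updates on the same batch.
with torch.no_grad():
    old_probs = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)

new_probs = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)
loss = ppo_loss(new_probs, old_probs, advantage)
optimizer.zero_grad()
loss.backward()
optimizer.step()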

šŸ¤– Implementing CoT Few-Shot Prompting

Few-shot prompting skips fine-tuning entirely: worked Chain-of-Thought examples are placed in the prompt, and the pre-trained model imitates the demonstrated reasoning pattern.

from transformers import AutoModelForCausalLM, AutoTokenizer
 
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
 
def generate_cot_response(prompt, max_new_tokens=100):
    few_shot_examples = """
    Q: What is 12 + 15?
    A: Let's break it down. First, 12 + 10 = 22. Then, 22 + 5 = 27. Answer: 27.
    Q: What is 45 - 18?
    A: Let's break it down. First, 45 - 10 = 35. Then, 35 - 8 = 27. Answer: 27.
    """
    full_prompt = few_shot_examples + "\nQ: " + prompt + "\nA: "
    inputs = tokenizer(full_prompt, return_tensors="pt")
    # max_new_tokens bounds only the generated continuation, so the few-shot
    # prompt does not eat into the generation budget.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens,
                             pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
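
A quick call against the helper above (the question is arbitrary and the output will vary with the model):

print(generate_cot_response("What is 37 + 48?"))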

šŸŽÆ Training the Reward Model

To provide the reward signal for reinforcement learning, we train a DistilRoBERTa-base reward model that scores candidate responses; the loop below fits its scalar output to labeled examples.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
 
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=1)
 
def train_reward_model(train_dataloader, optimizer, criterion, epochs=3):
    model.train()
    for epoch in range(epochs):
        for batch in train_dataloader:
            # Tokenize the candidate responses and score them with the
            # single regression head (num_labels=1).
            inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
            labels = batch["labels"].float()
            optimizer.zero_grad()
            # squeeze(-1) keeps the batch dimension even when batch size is 1
            outputs = model(**inputs).logits.squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
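
One way to drive this loop, assuming batches are dicts with "text" and "labels" keys; the toy data, MSE loss, and AdamW settings below are illustrative, not the project's actual configuration:

from torch.utils.data import DataLoader

# Toy scored responses: a higher label marks the better reasoning chain.
train_data = [
    {"text": "Q: What is 2 + 2? A: 2 + 2 = 4. Answer: 4.", "labels": 1.0},
    {"text": "Q: What is 2 + 2? A: The answer is 5.", "labels": 0.0},
]

def collate(batch):
    return {"text": [ex["text"] for ex in batch],
            "labels": torch.tensor([ex["labels"] for ex in batch])}

train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = torch.nn.MSELoss()

train_reward_model(train_dataloader, optimizer, criterion, epochs=1)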

šŸ“Š Performance Analysis

To evaluate these approaches, we measure:

  • Accuracy on GSM8K problems (a scoring sketch follows this list)
  • Training stability (variance in policy updates)
  • Efficiency of fine-tuning (training steps to convergence)
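
For the first metric, GSM8K gold solutions end with a line of the form "#### <answer>", so accuracy can be scored by comparing final numbers. A minimal sketch (both helper names are ours):

import re

def extract_final_number(text):
    # Gold GSM8K solutions end with "#### <number>"; for model outputs we
    # fall back to the last number that appears in the text.
    match = re.search(r"####\s*(-?[\d,.]+)", text)
    if match:
        value = match.group(1)
    else:
        numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
        if not numbers:
            return None
        value = numbers[-1]
    return value.replace(",", "").rstrip(".")

def gsm8k_accuracy(predictions, references):
    # predictions: generated CoT strings; references: gold GSM8K solutions.
    correct = sum(extract_final_number(p) == extract_final_number(r)
                  for p, r in zip(predictions, references))
    return correct / len(references)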

šŸ“¢ Future Work: We aim to explore hybrid methods combining GRPO and CoT prompting.



šŸ’” Contribute: Experiment with hyperparameters and different datasets for further fine-tuning insights!

šŸ”— GitHub Repository