Comparing Fine-Tuning Approaches for Chain-of-Thought (CoT)
This project explores the effectiveness of three different fine-tuning strategies for Chain-of-Thought (CoT) reasoning in large language models:
- Group Relative Policy Optimization (GRPO)
- Proximal Policy Optimization (PPO)
- Few-Shot Prompting
We apply these techniques to GPT-Neo using the GSM8K dataset, a benchmark of grade-school math word problems that test arithmetic and logical reasoning. A reward model based on DistilRoBERTa-base is trained to provide structured feedback for reinforcement learning.
Implementing GRPO from Scratch
GRPO (Group Relative Policy Optimization) removes the need for a learned value function by computing advantages relative to a group of sampled completions, while a KL penalty keeps the policy close to a reference policy. Below is the simplified GRPO objective used in this implementation:
```python
import torch
import torch.nn as nn
import torch.optim as optim


def grpo_loss(policy, old_policy, advantage, epsilon=0.2, beta=0.01):
    # Probability ratio between the current policy and the frozen old policy.
    ratio = policy / old_policy
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    # Clipped surrogate objective, weighted by the group-relative advantage.
    policy_loss = torch.min(ratio * advantage, clipped_ratio * advantage)
    # KL penalty toward the reference policy: F.kl_div(input, target) computes
    # KL(target || input) with log-probabilities as input, so this is KL(policy || old_policy).
    kl_div = torch.nn.functional.kl_div(old_policy.log(), policy, reduction="batchmean")
    # Maximize the surrogate and penalize divergence from the reference policy.
    return -policy_loss.mean() + beta * kl_div


class GRPOAgent(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Small MLP policy head that outputs a probability distribution over actions.
        self.policy = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, output_dim), nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.policy(x)
```
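As a usage sketch (not part of the project code): the group-relative advantage can be formed by normalizing the rewards of a group of sampled completions, with `GRPOAgent` supplying the action probabilities. The reward values, tensor shapes, and dimensions below are illustrative, and the KL term is only a rough per-action approximation here because `policy` and `old_policy` are the probabilities of the sampled actions rather than full distributions.

```python
# Illustrative only: rewards for a group of 4 sampled completions of one prompt.
rewards = torch.tensor([0.2, 0.9, 0.4, 0.7])
# Group-relative advantage: normalize within the group (no value network needed).
advantage = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

agent = GRPOAgent(input_dim=16, output_dim=4)
states = torch.randn(4, 16)                 # dummy state features
actions = torch.randint(0, 4, (4,))         # dummy sampled actions

# Probabilities of the taken actions under the current and the frozen old policy.
probs = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)
old_probs = probs.detach()

loss = grpo_loss(probs, old_probs, advantage)
loss.backward()
```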
Implementing PPO from Scratch
PPO constrains policy updates using a clipped objective to prevent large policy shifts.
```python
def ppo_loss(policy, old_policy, advantage, epsilon=0.2):
    # Probability ratio between the current policy and the frozen old policy.
    ratio = policy / old_policy
    # Clip the ratio so a single update cannot move the policy too far.
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()


class PPOAgent(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Small MLP policy head that outputs a probability distribution over actions.
        self.policy = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, output_dim), nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.policy(x)
```
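A minimal single-rollout update sketch under the same toy setup (dummy states, actions, and placeholder advantages; in practice the advantages would come from a value-function baseline such as GAE):

```python
agent = PPOAgent(input_dim=16, output_dim=4)
optimizer = optim.Adam(agent.parameters(), lr=3e-4)

states = torch.randn(8, 16)                 # dummy rollout states
actions = torch.randint(0, 4, (8,))         # dummy sampled actions
advantage = torch.randn(8)                  # placeholder advantages

# Action probabilities under the frozen policy that collected the rollout.
old_probs = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1).detach()

for _ in range(4):                          # a few optimization epochs per rollout
    probs = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = ppo_loss(probs, old_probs, advantage)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```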
Implementing CoT Few-Shot Prompting
Few-shot prompting keeps the pre-trained model frozen and instead places worked CoT examples directly in the prompt.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")


def generate_cot_response(prompt, num_samples=3):
    # Worked chain-of-thought examples prepended to every question.
    few_shot_examples = """
Q: What is 12 + 15?
A: Let's break it down. First, 12 + 10 = 22. Then, 22 + 5 = 27. Answer: 27.
Q: What is 45 - 18?
A: Let's break it down. First, 45 - 10 = 35. Then, 35 - 8 = 27. Answer: 27.
"""
    full_prompt = few_shot_examples + "\nQ: " + prompt + "\nA: "
    inputs = tokenizer(full_prompt, return_tensors="pt")
    # Sample several continuations and cap only the length of the generated answer.
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        num_return_sequences=num_samples,
        pad_token_id=tokenizer.eos_token_id,
    )
    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
```
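For example (the question and the answer-extraction regex below are illustrative; sampled outputs from GPT-Neo will vary between runs):

```python
import re

samples = generate_cot_response("What is 23 + 19?", num_samples=3)
for i, text in enumerate(samples):
    # The generated text echoes the prompt, so keep only the part after our question.
    continuation = text.split("What is 23 + 19?")[-1]
    match = re.search(r"Answer:\s*(-?\d+)", continuation)
    print(f"Sample {i}: {match.group(1) if match else 'no final answer found'}")
```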
Training the Reward Model
To provide structured feedback for reinforcement learning, we train a DistilRoBERTa-base reward model to assign a scalar score to each reasoning trace.
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=1)


def train_reward_model(train_dataloader, optimizer, criterion, epochs=3):
    model.train()
    for epoch in range(epochs):
        for batch in train_dataloader:
            inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
            labels = batch["labels"].float()
            optimizer.zero_grad()
            # Single regression head: one scalar reward per reasoning trace.
            outputs = model(**inputs).logits.squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
```
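A sketch of how `train_reward_model` might be invoked, using a tiny in-memory dataset of (reasoning text, score) pairs and a mean-squared-error criterion; the example texts, scores, and hyperparameters are placeholders, not the project's actual training data.

```python
from torch.utils.data import DataLoader

# Toy scoring data: higher scores for better reasoning traces (illustrative).
examples = [
    {"text": "Q: What is 12 + 15? A: 12 + 10 = 22, 22 + 5 = 27. Answer: 27.", "labels": 1.0},
    {"text": "Q: What is 12 + 15? A: The answer is 30.", "labels": 0.0},
]

def collate(batch):
    return {
        "text": [ex["text"] for ex in batch],
        "labels": torch.tensor([ex["labels"] for ex in batch]),
    }

train_dataloader = DataLoader(examples, batch_size=2, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = torch.nn.MSELoss()

train_reward_model(train_dataloader, optimizer, criterion, epochs=1)
```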
Performance Analysis
To evaluate these approaches, we measure:
- Accuracy on GSM8K problems (a minimal accuracy check is sketched after this list)
- Training stability (variance in policy updates)
- Efficiency of fine-tuning (training steps to convergence)
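GSM8K stores the gold answer after the `####` marker in each example's answer field, so exact-match accuracy can be computed by comparing it with the last number the model produces. A minimal sketch, where the `predict` callable standing in for any of the three approaches is a placeholder:

```python
import re

def extract_final_number(text):
    """Return the last integer-like number in a string, or None."""
    numbers = re.findall(r"-?\d+", text.replace(",", ""))
    return numbers[-1] if numbers else None

def gsm8k_accuracy(dataset, predict):
    """dataset: iterable of {"question": str, "answer": str} in GSM8K format
    (gold answer follows '####'); predict: callable mapping question -> generated text."""
    correct = 0
    for example in dataset:
        gold = example["answer"].split("####")[-1].strip().replace(",", "")
        pred = extract_final_number(predict(example["question"]))
        correct += int(pred == gold)
    return correct / len(dataset)
```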
Future Work: We aim to explore hybrid methods combining GRPO and CoT prompting.
Contribute: Experiment with hyperparameters and different datasets for further fine-tuning insights!
GitHub Repository