Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Deploy, evaluate, and manage AI agents end-to-end on Microsoft Azure AI Foundry
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/calibrate_grader.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "azure-identity",
#   "azure-ai-projects",
# ]
# ///
r"""
calibrate_grader.py — Calibrate RFT grader pass_threshold before submitting a job.

Runs the base model on your training/validation data, scores each output
with your Python grader, and recommends the optimal pass_threshold.

Usage:
    python calibrate_grader.py --base-url <url> --api-key KEY \
        --model o4-mini --data train.jsonl --grader grader.py --n 30

    python calibrate_grader.py --model gpt-4.1-mini --data val.jsonl \
        --grader grader.py --n 20 \
        --tools '[{"type": "function", "function": {"name": "search", "parameters": {"type": "object", "properties": {}}}}]'
"""

import argparse
import bisect
import json
import math
import os
import random
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# `common` lives next to this script; make it importable regardless of the CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients  # noqa: E402

# Candidate pass_threshold values evaluated by calibrate().
THRESHOLDS = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 1.0)

# A threshold is "good" when the base model fails this share of examples;
# among good thresholds the one closest to IDEAL_FAIL_RATE is recommended.
GOOD_FAIL_RANGE = (0.25, 0.50)
IDEAL_FAIL_RATE = 0.35

# Histogram buckets: a score s lands in the first bucket whose upper edge is > s,
# and everything >= 0.9 lands in the last bucket.
_BUCKET_EDGES = (0.2, 0.4, 0.6, 0.8, 0.9)
_BUCKET_LABELS = ("0.0-0.2", "0.2-0.4", "0.4-0.6", "0.6-0.8", "0.8-0.9", "0.9-1.0")

# Prefix run_model() puts on output_text to signal a failed call.
_ERROR_PREFIX = "ERROR:"


def load_grader(grader_path):
    """Load and compile a Python grader file. Returns the grade() function.

    SECURITY: This executes the grader file as Python code. Only load grader
    files that you wrote or reviewed — never load untrusted files from the
    internet or unknown sources. The grader runs with the same permissions as
    this script.

    Exits the process with status 1 when the file does not exist or does not
    define a callable named ``grade``.
    """
    grader_path = os.path.abspath(grader_path)
    if not os.path.isfile(grader_path):
        print(f"❌ Grader file not found: {grader_path}")
        sys.exit(1)
    with open(grader_path, encoding="utf-8") as f:
        source = f.read()
    namespace = {}
    # Compile with the real path so tracebacks from the grader point at its file.
    exec(compile(source, grader_path, "exec"), namespace)
    grade = namespace.get("grade")
    if not callable(grade):
        print("❌ Grader file must define a grade(sample, item) function")
        sys.exit(1)
    return grade


def _is_rate_limited(exc):
    """Return True when *exc* looks like an HTTP 429 (rate limit) error."""
    return getattr(exc, "status_code", None) == 429 or "429" in str(exc)


def run_model(client, model, messages, tools_schema=None, max_retries=3):
    """Run the model and return (output_text, output_tools).

    ``output_tools`` is a list of ``{"type": "function", "function": {...}}``
    dicts built from the tool calls in the first choice (empty when none).

    Failures are not raised: the function returns ``("ERROR: <reason>", [])``
    instead. Rate-limit errors are retried with a linearly growing delay
    (5s, 10s, ...) for up to ``max_retries`` attempts in total.
    """
    kwargs = {"model": model, "messages": messages, "max_completion_tokens": 4096}
    if tools_schema:
        kwargs["tools"] = tools_schema

    for attempt in range(max_retries):
        try:
            resp = client.chat.completions.create(**kwargs)
            msg = resp.choices[0].message
            output_text = msg.content or ""
            output_tools = []
            if msg.tool_calls:
                output_tools = [
                    {"type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                    for tc in msg.tool_calls
                ]
            return output_text, output_tools
        except Exception as e:  # boundary: any client failure becomes an ERROR result
            if _is_rate_limited(e) and attempt < max_retries - 1:
                time.sleep(5 * (attempt + 1))
            else:
                return f"ERROR: {e}", []
    # Only reachable when max_retries <= 0 (the loop body never ran).
    return "ERROR: max retries", []


def _last_message_preview(messages, width=55):
    """Return up to *width* characters of the last message's content for display."""
    last = messages[-1] if messages else None
    content = last.get("content") if isinstance(last, dict) else None
    return "" if content is None else str(content)[:width]


def _score_examples(client, model, data, grade_fn, tools_schema):
    """Run the model and grader over *data*; return (scores, failures).

    ``scores`` holds one finite float per example that was both generated and
    graded successfully. ``failures`` counts examples that were skipped because
    the example was malformed, the model call failed, or the grader raised or
    returned a non-numeric / non-finite value.
    """
    scores = []
    failures = 0
    for idx, ex in enumerate(data, 1):
        messages = ex.get("messages") if isinstance(ex, dict) else None
        if not messages:
            print(f"  [{idx:3d}] ❌ Example has no 'messages'; skipped")
            failures += 1
            continue

        output_text, output_tools = run_model(client, model, messages, tools_schema)
        if output_text.startswith(_ERROR_PREFIX):
            print(f"  [{idx:3d}] ❌ {output_text[:60]}")
            failures += 1
            continue

        # Build sample dict matching what the grader expects
        sample = {"output_text": output_text, "output_tools": output_tools}
        # Build item dict from all fields in the training example
        item = {k: v for k, v in ex.items() if k != "messages"}

        try:
            score = float(grade_fn(sample, item))
            if not math.isfinite(score):
                raise ValueError(f"grader returned a non-finite score: {score}")
        except Exception as e:  # boundary: user-supplied grader code may raise anything
            print(f"  [{idx:3d}] ❌ Grader error: {e}")
            failures += 1
            continue

        status = "✅" if score >= 0.9 else ("⚠️" if score >= 0.5 else "❌")
        print(f"  [{idx:3d}] {score:.3f} {status} {_last_message_preview(messages)}")
        scores.append(score)

        time.sleep(0.5)  # Rate limiting
    return scores, failures


def _classify_fail_rate(fail_rate):
    """Map a failure rate to the signal label shown in the threshold table."""
    low, high = GOOD_FAIL_RANGE
    if low <= fail_rate <= high:
        return "✅ Good (25-50%)"
    if fail_rate < 0.10:
        return "❌ Too easy"
    if fail_rate < low:
        return "⚠️ Weak signal"
    if fail_rate <= 0.70:
        return "⚠️ Harsh"
    return "❌ Too hard"


def _print_threshold_table(scores):
    """Print pass/fail rates per candidate threshold; return the recommended one (or None)."""
    print(f"\n  {'Threshold':>10} {'Pass Rate':>10} {'Fail Rate':>10} {'Signal':>20}")
    print(f"  {'-'*10} {'-'*10} {'-'*10} {'-'*20}")

    low, high = GOOD_FAIL_RANGE
    total = len(scores)
    best_threshold = None
    best_distance = float("inf")

    for threshold in THRESHOLDS:
        # Derive both rates from counts so band edges (e.g. exactly 25%) are not
        # pushed out of range by the rounding error of `1 - pass_rate`.
        failed = sum(1 for s in scores if s < threshold)
        fail_rate = failed / total
        pass_rate = (total - failed) / total

        if low <= fail_rate <= high:
            distance = abs(fail_rate - IDEAL_FAIL_RATE)
            # Strict '<' keeps the lowest threshold when two are equally close.
            if distance < best_distance:
                best_distance = distance
                best_threshold = threshold

        signal = _classify_fail_rate(fail_rate)
        print(f"  {threshold:>10.2f} {pass_rate:>10.0%} {fail_rate:>10.0%} {signal:>20}")

    return best_threshold


def _print_distribution(scores):
    """Print a text histogram of *scores* over the fixed score buckets."""
    counts = [0] * len(_BUCKET_LABELS)
    for s in scores:
        counts[bisect.bisect_right(_BUCKET_EDGES, s)] += 1

    print("\n  Score distribution:")
    for label, count in zip(_BUCKET_LABELS, counts):
        print(f"  {label}: {count:3d} {'█' * count}")


def calibrate(client, model, data, grade_fn, tools_schema=None, n=30):
    """Run base model on data, score with grader, output threshold analysis.

    Args:
        client: Client exposing ``chat.completions.create`` (see run_model).
        model: Deployment/model name passed to the client.
        data: List of example dicts; each needs a non-empty ``"messages"`` list.
            All other keys are handed to the grader as ``item``.
        grade_fn: ``grade(sample, item)`` callable returning a numeric score.
        tools_schema: Optional tool definitions forwarded to the model call.
        n: Maximum number of examples to evaluate; larger datasets are randomly
            sampled down to ``n`` using the global ``random`` state.

    Returns:
        The recommended pass_threshold, or None when there was nothing to score
        or no candidate threshold falls in the 25-50% failure range.

    Raises:
        ValueError: If ``n`` is less than 1.

    Examples whose model call or grading failed are excluded from the analysis
    (and reported) rather than being counted as a score of 0, so infrastructure
    errors do not distort the recommended threshold.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if not data:
        print("No examples to evaluate. Check your data file.")
        return None

    # Sample if dataset is larger than n
    if len(data) > n:
        data = random.sample(data, n)

    print(f"Running {model} on {len(data)} examples...\n")
    scores, failures = _score_examples(client, model, data, grade_fn, tools_schema)

    if not scores:
        print("\n❌ No examples were scored successfully. Check model access and data format.")
        return None

    avg = sum(scores) / len(scores)
    print(f"\n{'='*60}")
    print(f"  BASE MODEL GRADER CALIBRATION ({len(scores)} examples)")
    print(f"  Average score: {avg:.1%}")
    if failures:
        print(f"  Excluded from analysis: {failures} example(s) that failed to run or grade")
    print(f"{'='*60}")

    best_threshold = _print_threshold_table(scores)

    if best_threshold is not None:
        fail_share = sum(1 for s in scores if s < best_threshold) / len(scores)
        print(f"\n  ✅ Recommended pass_threshold: {best_threshold}")
        print(f"     (~{fail_share:.0%} failure rate)")
    else:
        print("\n  ⚠️ No threshold in the ideal 25-50% failure range.")
        print("     Consider adjusting your grader scoring dimensions.")

    _print_distribution(scores)
    return best_threshold


def _positive_int(value):
    """argparse type: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def build_parser():
    """Build the command-line parser; connection options default to environment variables."""
    parser = HelpOnErrorParser(
        description="Calibrate RFT grader pass_threshold on base model outputs",
        epilog=(
            "Example:\n"
            "  python calibrate_grader.py --model o4-mini --data train.jsonl --grader grader.py\n"
            "  python calibrate_grader.py --model o4-mini --data val.jsonl --grader grader.py --n 20"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"), help="Project /v1/ endpoint URL")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"), help="API key")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint")
    parser.add_argument("--model", required=True, help="Base model deployment name to calibrate against")
    parser.add_argument("--data", required=True, help="Path to training or validation JSONL file")
    parser.add_argument("--grader", required=True, help="Path to Python grader file (must define grade(sample, item))")
    parser.add_argument("--n", type=_positive_int, default=30, help="Number of examples to evaluate (default: 30)")
    parser.add_argument("--tools", default=None,
                        help="Tool schemas as JSON array (for tool-calling models). Pass as a JSON string.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling (default: 42)")
    return parser


def _load_examples(path):
    """Read a JSONL file into a list of dicts, skipping blank and malformed lines.

    Exits the process with status 1 when the file cannot be opened.
    """
    examples = []
    try:
        with open(path, encoding="utf-8") as f:
            for ln, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"⚠️ Skipping malformed JSON on line {ln}: {e}")
                    continue
                if not isinstance(record, dict):
                    # calibrate() reads examples as dicts; anything else cannot be graded.
                    print(f"⚠️ Skipping line {ln}: expected a JSON object, got {type(record).__name__}")
                    continue
                examples.append(record)
    except OSError as e:
        print(f"❌ Could not read data file {path}: {e}")
        sys.exit(1)
    return examples


def _parse_tools(raw):
    """Parse the --tools JSON string into a list, or return None when not given.

    Exits the process with status 1 when the value is not a JSON array.
    """
    if not raw:
        return None
    try:
        tools = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"❌ --tools is not valid JSON: {e}")
        sys.exit(1)
    if not isinstance(tools, list):
        print("❌ --tools must be a JSON array of tool schemas")
        sys.exit(1)
    return tools


def main(argv=None):
    """Command-line entry point: parse arguments, load inputs, run the calibration."""
    if argv is None:
        argv = sys.argv[1:]

    parser = build_parser()
    if not argv:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(argv)
    # calibrate() samples with the global RNG, so seeding here makes runs repeatable.
    random.seed(args.seed)

    client, _method = get_clients(
        base_url=args.base_url,
        azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint,
        api_key=args.api_key,
    )

    # Load data
    data = _load_examples(args.data)
    print(f"Loaded {len(data)} examples from {args.data}")

    # Load grader
    grade_fn = load_grader(args.grader)
    print(f"Loaded grader from {args.grader}")

    # Parse tools if provided
    tools_schema = _parse_tools(args.tools)

    calibrate(client, args.model, data, grade_fn, tools_schema, args.n)


if __name__ == "__main__":
    main()