Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Build and deploy AI applications on Azure AI Foundry using Microsoft's model catalog and AI services
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/calibrate_grader.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "azure-identity",
#   "azure-ai-projects",
# ]
# ///
"""
calibrate_grader.py — Calibrate RFT grader pass_threshold before submitting a job.

Runs the base model on your training/validation data, scores each output
with your Python grader, and recommends the optimal pass_threshold.

Usage:
    python calibrate_grader.py --base-url <url> --api-key KEY \\
        --model o4-mini --data train.jsonl --grader grader.py --n 30

    python calibrate_grader.py --model gpt-4.1-mini --data val.jsonl \\
        --grader grader.py --n 20 --tools '[{"name": "search", "server_url": "https://..."}]'
"""

import argparse
import json
import math
import os
import random
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# Make the sibling `common` module importable regardless of the caller's cwd.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients  # noqa: E402  (needs the sys.path tweak above)

# Candidate pass_threshold values, in ascending order (ties in the
# recommendation therefore resolve to the lower threshold).
CANDIDATE_THRESHOLDS = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 1.0)

# Histogram buckets for the score distribution: (label, exclusive upper bound).
SCORE_BUCKETS = (
    ("0.0-0.2", 0.2),
    ("0.2-0.4", 0.4),
    ("0.4-0.6", 0.6),
    ("0.6-0.8", 0.8),
    ("0.8-0.9", 0.9),
    ("0.9-1.0", math.inf),
)


def load_grader(grader_path):
    """Load and compile a Python grader file. Returns the grade() function.

    SECURITY: This executes the grader file as Python code. Only load grader
    files that you wrote or reviewed — never load untrusted files from the
    internet or unknown sources. The grader runs with the same permissions as
    this script.

    The grader is executed with ``__name__`` set to ``"grader"`` (so an
    ``if __name__ == "__main__":`` self-test block does not run) and
    ``__file__`` set to its absolute path (so it can locate files next to it).

    Exits the process with status 1 if the file is missing or does not
    define a callable ``grade``.
    """
    grader_path = os.path.abspath(grader_path)
    if not os.path.isfile(grader_path):
        print(f"❌ Grader file not found: {grader_path}", file=sys.stderr)
        sys.exit(1)
    with open(grader_path, encoding="utf-8") as f:
        source = f.read()
    namespace = {"__name__": "grader", "__file__": grader_path}
    exec(compile(source, grader_path, "exec"), namespace)
    grade_fn = namespace.get("grade")
    if not callable(grade_fn):
        print("❌ Grader file must define a grade(sample, item) function", file=sys.stderr)
        sys.exit(1)
    return grade_fn


def _is_rate_limit_error(exc):
    """Return True if *exc* looks like an HTTP 429 (rate limit) API error."""
    return getattr(exc, "status_code", None) == 429 or "429" in str(exc)


def run_model(client, model, messages, tools_schema=None, max_retries=3):
    """Run the model and return (output_text, output_tools).

    HTTP 429 responses are retried with linear backoff (5s, 10s, ...) for up
    to ``max_retries`` total attempts. Any other failure, or running out of
    retries, is returned as ``("ERROR: <message>", [])`` rather than raised,
    so one bad example does not abort the whole calibration run.
    """
    kwargs = {"model": model, "messages": messages, "max_completion_tokens": 4096}
    if tools_schema:
        kwargs["tools"] = tools_schema

    for attempt in range(max_retries):
        try:
            resp = client.chat.completions.create(**kwargs)
            msg = resp.choices[0].message
            output_tools = [
                {"type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                for tc in (msg.tool_calls or [])
            ]
            return msg.content or "", output_tools
        except Exception as e:  # API boundary: report as an error row, don't crash
            if _is_rate_limit_error(e) and attempt < max_retries - 1:
                time.sleep(5 * (attempt + 1))
                continue
            return f"ERROR: {e}", []
    return "ERROR: max retries", []


def _message_preview(messages, width=55):
    """Return a single-line preview (at most *width* chars) of the last message's text."""
    last = messages[-1] if messages else None
    content = last.get("content") if isinstance(last, dict) else None
    if isinstance(content, list):  # multimodal content parts: keep only text
        content = " ".join(str(part.get("text") or "") for part in content if isinstance(part, dict))
    # Collapse newlines/runs of whitespace so each result stays on one row.
    return " ".join(str(content or "").split())[:width]


def _score_example(client, model, ex, grade_fn, tools_schema):
    """Run the model on one example and grade its output.

    Returns ``(score, error)``. On success ``score`` is a finite float and
    ``error`` is None; otherwise ``score`` is None and ``error`` is a short
    description (malformed example, API error, or grader failure).
    """
    if not isinstance(ex, dict) or not ex.get("messages"):
        return None, "Example has no 'messages' list"

    output_text, output_tools = run_model(client, model, ex["messages"], tools_schema)
    if output_text.startswith("ERROR:"):
        return None, output_text

    # Build sample dict matching what the grader expects.
    sample = {"output_text": output_text, "output_tools": output_tools}
    # Build item dict from all fields in the training example.
    item = {k: v for k, v in ex.items() if k != "messages"}

    try:
        # float() also rejects non-numeric returns (e.g. None) as grader errors.
        score = float(grade_fn(sample, item))
    except Exception as e:  # user-supplied grader: report and keep going
        return None, f"Grader error: {e}"
    if not math.isfinite(score):
        return None, f"Grader error: non-finite score {score}"
    return score, None


def _print_threshold_table(scored):
    """Print pass/fail rates for each candidate threshold and return the best one.

    The best threshold is the one whose failure rate lies in the 25-50% band
    and is closest to ~35%; returns None if no candidate falls in the band.
    """
    print(f"\n {'Threshold':>10} {'Pass Rate':>10} {'Fail Rate':>10} {'Signal':>20}")
    print(f" {'-'*10} {'-'*10} {'-'*10} {'-'*20}")

    total = len(scored)
    best_threshold = None
    best_distance = float("inf")

    for threshold in CANDIDATE_THRESHOLDS:
        passed = sum(1 for s in scored if s >= threshold)
        pass_rate = passed / total
        # Computed from the fail count (not 1 - pass_rate) to avoid float drift at the band edges.
        fail_rate = (total - passed) / total

        if 0.25 <= fail_rate <= 0.50:
            signal = "✅ Good (25-50%)"
            distance = abs(fail_rate - 0.35)  # Ideal is ~35%
            if distance < best_distance:
                best_distance = distance
                best_threshold = threshold
        elif fail_rate < 0.10:
            signal = "❌ Too easy"
        elif fail_rate < 0.25:
            signal = "⚠️ Weak signal"
        elif fail_rate <= 0.70:
            signal = "⚠️ Harsh"
        else:
            signal = "❌ Too hard"

        print(f" {threshold:>10.2f} {pass_rate:>9.0%} {fail_rate:>9.0%} {signal:>20}")

    return best_threshold


def _print_distribution(scored):
    """Print a text histogram of *scored* using SCORE_BUCKETS."""
    print("\n Score distribution:")
    counts = {label: 0 for label, _ in SCORE_BUCKETS}
    for s in scored:
        for label, upper in SCORE_BUCKETS:
            if s < upper:
                counts[label] += 1
                break
    for label, count in counts.items():
        print(f" {label}: {count:3d} {'█' * count}")


def calibrate(client, model, data, grade_fn, tools_schema=None, n=30):
    """Run base model on data, score with grader, output threshold analysis.

    Args:
        client: Client exposing ``chat.completions.create`` (from get_clients).
        model: Deployment/model name to run.
        data: List of examples; each should be a dict with a ``messages``
            list. All other keys are passed to the grader as ``item``.
        grade_fn: ``grade(sample, item)`` function loaded from the grader file.
        tools_schema: Optional list forwarded as the ``tools`` request parameter.
        n: Maximum number of examples to evaluate (random sample if larger).

    Examples that fail (malformed example, API error, grader exception or
    non-numeric score) are reported and EXCLUDED from the statistics, so
    infrastructure failures do not masquerade as 0.0 scores and skew the
    recommended threshold.
    """
    if not data:
        print("No examples to evaluate. Check your data file.")
        return

    # Sample if dataset is larger than n (clamped so a negative n can't raise).
    if len(data) > n:
        data = random.sample(data, max(n, 0))

    print(f"Running {model} on {len(data)} examples...\n")

    scored = []
    errors = 0
    for i, ex in enumerate(data, 1):
        score, error = _score_example(client, model, ex, grade_fn, tools_schema)
        if error is not None:
            print(f" [{i:3d}] ❌ {error[:80]}")
            errors += 1
            continue

        status = "✅" if score >= 0.9 else ("⚠️" if score >= 0.5 else "❌")
        print(f" [{i:3d}] {score:.3f} {status} {_message_preview(ex['messages'])}")
        scored.append(score)

        time.sleep(0.5)  # Rate limiting

    # Analysis
    if not scored:
        print("\n❌ No examples were scored successfully. Check model access and data format.")
        return

    avg = sum(scored) / len(scored)
    print(f"\n{'='*60}")
    print(f" BASE MODEL GRADER CALIBRATION ({len(scored)} examples)")
    if errors:
        print(f" Excluded {errors} example(s) that errored")
    print(f" Average score: {avg:.1%}")
    print(f"{'='*60}")

    best_threshold = _print_threshold_table(scored)

    if best_threshold is not None:
        fail_rate = sum(1 for s in scored if s < best_threshold) / len(scored)
        print(f"\n ✅ Recommended pass_threshold: {best_threshold}")
        print(f" (~{fail_rate:.0%} failure rate)")
    else:
        print("\n ⚠️ No threshold in the ideal 25-50% failure range.")
        print(" Consider adjusting your grader scoring dimensions.")

    _print_distribution(scored)


def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an integer >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return number


def build_parser():
    """Build the command-line parser for this script."""
    parser = HelpOnErrorParser(
        description="Calibrate RFT grader pass_threshold on base model outputs",
        epilog=(
            "Example:\n"
            " python calibrate_grader.py --model o4-mini --data train.jsonl --grader grader.py\n"
            " python calibrate_grader.py --model o4-mini --data val.jsonl --grader grader.py --n 20"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"), help="Project /v1/ endpoint URL")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"), help="API key")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint")
    parser.add_argument("--model", required=True, help="Base model deployment name to calibrate against")
    parser.add_argument("--data", required=True, help="Path to training or validation JSONL file")
    parser.add_argument("--grader", required=True, help="Path to Python grader file (must define grade(sample, item))")
    parser.add_argument("--n", type=_positive_int, default=30, help="Number of examples to evaluate (default: 30)")
    parser.add_argument("--tools", default=None,
                        help="Tool schemas as JSON array (for tool-calling models). Pass as a JSON string.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling (default: 42)")
    return parser


def load_data(path):
    """Load a JSONL file into a list of parsed rows.

    Blank lines are ignored; malformed JSON lines are reported and skipped.
    A leading UTF-8 BOM is tolerated ("utf-8-sig"). Exits with status 1 if
    the file cannot be read or decoded.
    """
    data = []
    try:
        with open(path, encoding="utf-8-sig") as f:
            for ln, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"⚠️ Skipping malformed JSON on line {ln}: {e}")
    except (OSError, UnicodeDecodeError) as e:
        print(f"❌ Cannot read data file {path}: {e}", file=sys.stderr)
        sys.exit(1)
    return data


def _parse_tools(raw):
    """Parse the ``--tools`` JSON string; return None when not provided.

    Exits with status 1 if the string is not a valid JSON array.
    """
    if not raw:
        return None
    try:
        tools = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"❌ --tools is not valid JSON: {e}", file=sys.stderr)
        sys.exit(1)
    if not isinstance(tools, list):
        print("❌ --tools must be a JSON array", file=sys.stderr)
        sys.exit(1)
    # NOTE(review): this list is passed verbatim as the Chat Completions `tools`
    # parameter (see run_model); the module docstring's example uses a different
    # entry shape than {"type": "function", "function": {...}} — confirm.
    return tools


def main(argv=None):
    """CLI entry point: parse args, load local inputs, then run the calibration."""
    parser = build_parser()
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(argv)
    random.seed(args.seed)

    # Validate local inputs before any client/auth setup so mistakes fail fast.
    data = load_data(args.data)
    print(f"Loaded {len(data)} examples from {args.data}")

    grade_fn = load_grader(args.grader)
    print(f"Loaded grader from {args.grader}")

    tools_schema = _parse_tools(args.tools)

    client, _method = get_clients(base_url=args.base_url, azure_endpoint=args.endpoint,
                                  project_endpoint=args.project_endpoint, api_key=args.api_key)

    calibrate(client, args.model, data, grade_fn, tools_schema, args.n)


if __name__ == "__main__":
    main()