Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Deploy, evaluate, and manage AI agents end-to-end on Microsoft Azure AI Foundry
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/generate_distillation_data.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "azure-identity",
# ]
# ///
"""
generate_distillation_data.py — Generate training data from a teacher model for distillation.

Creates a synthetic SFT dataset by:
1. Generating diverse prompts from combinatorial axes (topics × formats × contexts)
2. Having the teacher model produce responses
3. Quality-grading each response with an LLM judge
4. Filtering low-quality examples
5. Splitting into train/val/test JSONL files

Usage:
    python generate_distillation_data.py \
        --teacher gpt-4.1-mini \
        --system-prompt "You are a formal business writer." \
        --topics "earnings,risk,compliance" \
        --num-prompts 300 \
        --min-score 7.0 \
        --output-dir ./my_dataset

    # Or with a prompts file (one prompt per line):
    python generate_distillation_data.py \
        --teacher gpt-4.1-mini \
        --prompts-file my_prompts.txt \
        --output-dir ./my_dataset

Pass --seed N to make prompt sampling and the final shuffle reproducible.
"""

import json
import math
import os
import random
import re
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# Put this script's own directory first on sys.path so the `common` import
# below resolves no matter which working directory the script is run from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients, _clamp_score

import openai

# Tolerance for float round-off when turning split fractions into indices
# (e.g. 0.7 + 0.1 == 0.7999999999999999 in binary floating point).
_SPLIT_EPS = 1e-9

DEFAULT_TOPICS = "general knowledge"
DEFAULT_FORMATS = "a concise response,a brief summary,a detailed explanation"


def verify_deployment(client, model):
    """Verify a model deployment exists by sending a trivial request.

    Sends a one-token chat completion to ``model``.

    Returns:
        False only when the API answers with ``openai.NotFoundError``;
        True on success and on any other exception.
    """
    try:
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Hi"}],
            max_completion_tokens=1,
        )
        return True
    except openai.NotFoundError:
        return False
    except Exception:
        return True  # other errors (rate limit, etc.) mean the deployment exists


def generate_combinatorial_prompts(topics, formats, contexts, n):
    """Generate diverse prompts from combinatorial axes.

    Each prompt is built from one independently sampled topic, format and
    context (sampling is with replacement, so duplicates are possible).

    Args:
        topics: Non-empty sequence of topic strings.
        formats: Non-empty sequence of output-format phrases.
        contexts: Non-empty sequence of context sentences; an empty string
            means "no context" and produces a prompt without a preamble.
        n: Number of prompts to generate; values <= 0 yield an empty list.

    Returns:
        A list of ``n`` prompt strings.
    """
    prompts = []
    for _ in range(n):
        t = random.choice(topics)
        f = random.choice(formats)
        c = random.choice(contexts)
        request = f"Write {f} about: {t}."
        # Skip the preamble for an empty context; otherwise every prompt
        # would start with a dangling "Context: " line.
        prompts.append(f"Context: {c}\n\n{request}" if c else request)
    return prompts


def teacher_generate(client, model, system_prompt, prompt, retries=3):
    """Generate a single response from the teacher.

    Retries on any exception, sleeping 2s, 4s, ... between attempts.

    Returns:
        The message content of the first choice, or None once all
        ``retries`` attempts have failed.
    """
    for attempt in range(retries):
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                max_completion_tokens=1024,
            )
            return resp.choices[0].message.content
        except Exception as e:
            if attempt >= retries - 1:
                print(f" Failed after {retries} attempts: {e}")
                return None
            time.sleep(2 * (attempt + 1))
    return None


QUALITY_PROMPT = """Rate this AI-generated text on quality dimensions (1-10 each).

## Text to evaluate
{output}

## Dimensions
**Accuracy** (1-10): Is the content factually sound and coherent?
**Quality** (1-10): Is it well-written, clear, and professional?
**Task-fit** (1-10): Does it match the requested format and purpose?

Return ONLY JSON: {{"accuracy": <int>, "quality": <int>, "task_fit": <int>}}"""


def grade_output(client, judge_model, output, retries=3):
    """Ask the judge model to score ``output`` and return the parsed scores.

    The first ``{...}`` object in the judge's reply is parsed as JSON and
    every value is passed through ``_clamp_score`` (defined in ``common``;
    presumably it coerces/clamps to the 1-10 scale — confirm there).

    Returns:
        A dict mapping dimension name to score, or None when no attempt
        produced parseable JSON. The last failure reason is printed so
        dropped examples are not silent.
    """
    # The prompt does not change between attempts, so build it once.
    judge_prompt = QUALITY_PROMPT.format(output=output)
    last_error = None
    for attempt in range(retries):
        try:
            resp = client.chat.completions.create(
                model=judge_model,
                messages=[{"role": "user", "content": judge_prompt}],
                temperature=0.0,
                max_completion_tokens=100,
            )
            text = (resp.choices[0].message.content or "").strip()
            match = re.search(r'\{[^}]+\}', text)
            if match:
                scores = json.loads(match.group())
                return {k: _clamp_score(v) for k, v in scores.items()}
            last_error = "no JSON object found in judge reply"
        except Exception as e:
            last_error = e
            if attempt < retries - 1:
                time.sleep(2)
    if last_error is not None:
        print(f" Grading failed after {retries} attempts: {last_error}")
    return None


def _split_csv(value):
    """Split a comma-separated string into stripped, non-empty items."""
    return [item.strip() for item in (value or "").split(",") if item.strip()]


def _split_bounds(n, train_split, val_split):
    """Return ``(train_end, val_end)`` slice indices for ``n`` examples.

    Uses a floor with a tiny tolerance so that float round-off in
    ``train_split + val_split`` cannot push a boundary down by one.
    """
    train_end = math.floor(n * train_split + _SPLIT_EPS)
    val_end = math.floor(n * (train_split + val_split) + _SPLIT_EPS)
    train_end = min(max(train_end, 0), n)
    val_end = min(max(val_end, train_end), n)
    return train_end, val_end


def main():
    """CLI entry point: generate, grade, filter and split a distillation dataset."""
    parser = HelpOnErrorParser(description="Generate distillation training data from a teacher model")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"))
    parser.add_argument("--teacher", required=True, help="Teacher model deployment name")
    parser.add_argument("--judge", default=None, help="Judge model (default: same as teacher)")
    parser.add_argument("--system-prompt", default="You are a helpful assistant.", help="System prompt for teacher")

    # Prompt generation (either combinatorial or from file)
    parser.add_argument("--prompts-file", help="File with one prompt per line (skips combinatorial generation)")
    parser.add_argument("--topics", help="Comma-separated topics for combinatorial prompts")
    parser.add_argument("--formats", default=DEFAULT_FORMATS,
                        help="Comma-separated output formats")
    parser.add_argument("--contexts", default="", help="Comma-separated context sentences")
    parser.add_argument("--num-prompts", type=int, default=300, help="Number of prompts to generate")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducible prompt sampling and shuffling")

    # Quality
    parser.add_argument("--min-score", type=float, default=7.0, help="Minimum average quality score to keep")
    parser.add_argument("--skip-grading", action="store_true", help="Skip quality grading (keep all)")

    # Output
    parser.add_argument("--output-dir", default="./distillation_data", help="Output directory")
    parser.add_argument("--train-split", type=float, default=0.8)
    parser.add_argument("--val-split", type=float, default=0.1)

    args = parser.parse_args()

    # Reject nonsensical split fractions before any API call is made.
    if (args.train_split < 0 or args.val_split < 0
            or args.train_split + args.val_split > 1 + _SPLIT_EPS):
        print(" ERROR: --train-split and --val-split must be non-negative and sum to at most 1.0.")
        sys.exit(1)

    if args.seed is not None:
        random.seed(args.seed)

    client, method = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key
    )
    judge = args.judge or args.teacher

    # Step 0: Verify deployments exist
    print(f"Verifying deployment '{args.teacher}'...")
    if not verify_deployment(client, args.teacher):
        print(f" ERROR: Deployment '{args.teacher}' not found. Available deployments can be listed in Azure Portal.")
        sys.exit(1)
    print(" ✅ Teacher deployment verified.")

    if judge != args.teacher:
        print(f"Verifying judge deployment '{judge}'...")
        if not verify_deployment(client, judge):
            print(f" ERROR: Judge deployment '{judge}' not found.")
            sys.exit(1)
        print(" ✅ Judge deployment verified.")

    # Step 1: Generate or load prompts
    if args.prompts_file:
        with open(args.prompts_file, encoding="utf-8") as pf:
            prompts = [line.strip() for line in pf if line.strip()]
        print(f"Loaded {len(prompts)} prompts from {args.prompts_file}")
    else:
        # Empty items (",," or a trailing comma) are dropped; an axis that ends
        # up empty falls back to its default so random.choice never sees [].
        topics = _split_csv(args.topics) or _split_csv(DEFAULT_TOPICS)
        formats = _split_csv(args.formats) or _split_csv(DEFAULT_FORMATS)
        contexts = _split_csv(args.contexts) or [""]
        prompts = generate_combinatorial_prompts(topics, formats, contexts, args.num_prompts)
        print(f"Generated {len(prompts)} prompts ({len(topics)} topics × {len(formats)} formats × {len(contexts)} contexts)")

    if not prompts:
        print(" ERROR: No prompts to process (empty prompts file or --num-prompts <= 0).")
        sys.exit(1)

    # Step 2: Teacher generates responses
    print(f"\nTeacher ({args.teacher}) generating responses...")
    examples = []
    for i, prompt in enumerate(prompts):
        response = teacher_generate(client, args.teacher, args.system_prompt, prompt)
        if response:
            examples.append({"prompt": prompt, "response": response})
        if (i + 1) % 25 == 0:
            print(f" {i+1}/{len(prompts)} ({len(examples)} successful)")
    print(f" Teacher produced {len(examples)}/{len(prompts)} responses")

    # Step 3: Quality grade and filter
    if not args.skip_grading:
        print(f"\nGrading with {judge}...")
        ungraded = 0
        for i, ex in enumerate(examples):
            scores = grade_output(client, judge, ex["response"])
            if scores:
                ex["scores"] = scores
                ex["avg_score"] = sum(scores.values()) / len(scores)
            else:
                # Ungradable examples get 0 so the min-score filter drops them.
                ex["avg_score"] = 0
                ungraded += 1
            if (i + 1) % 25 == 0:
                print(f" Graded {i+1}/{len(examples)}")

        filtered = [ex for ex in examples if ex["avg_score"] >= args.min_score]
        # Only successfully graded examples contribute to the statistics.
        avgs = [ex["avg_score"] for ex in examples if "scores" in ex]
        print(f" Passed filter (>= {args.min_score}): {len(filtered)}/{len(examples)}")
        if ungraded:
            print(f" Could not grade {ungraded} example(s); they were scored 0")
        if avgs:
            print(f" Scores: min={min(avgs):.1f}, max={max(avgs):.1f}, mean={sum(avgs)/len(avgs):.1f}")
    else:
        filtered = examples
        print(f"Skipping grading — keeping all {len(filtered)} examples")

    if not filtered:
        print(" WARNING: No examples survived; the output files will be empty.")

    # Step 4: Convert to SFT format and split
    sft_data = [{"messages": [
        {"role": "system", "content": args.system_prompt},
        {"role": "user", "content": ex["prompt"]},
        {"role": "assistant", "content": ex["response"]},
    ]} for ex in filtered]

    random.shuffle(sft_data)
    n = len(sft_data)
    t_end, v_end = _split_bounds(n, args.train_split, args.val_split)
    splits = {"train": sft_data[:t_end], "validation": sft_data[t_end:v_end], "test": sft_data[v_end:]}

    os.makedirs(args.output_dir, exist_ok=True)
    for name, data in splits.items():
        path = os.path.join(args.output_dir, f"{name}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            for ex in data:
                f.write(json.dumps(ex, ensure_ascii=False) + "\n")
        print(f" {name}: {len(data)} examples → {path}")

    print(f"\n✅ Done! Dataset ready in {args.output_dir}/")


if __name__ == "__main__":
    main()