Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Build and deploy AI applications on Azure AI Foundry using Microsoft's model catalog and AI services
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/generate_distillation_data.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "azure-identity",
# ]
# ///
r"""
generate_distillation_data.py — Generate training data from a teacher model for distillation.

Creates a synthetic SFT dataset by:
1. Generating diverse prompts from combinatorial axes (topics × formats × contexts)
2. Having the teacher model produce responses
3. Quality-grading each response with an LLM judge
4. Filtering low-quality examples
5. Splitting into train/val/test JSONL files

Usage:
    python generate_distillation_data.py \
        --teacher gpt-4.1-mini \
        --system-prompt "You are a formal business writer." \
        --topics "earnings,risk,compliance" \
        --num-prompts 300 \
        --min-score 7.0 \
        --output-dir ./my_dataset

    # Or with a prompts file (one prompt per line):
    python generate_distillation_data.py \
        --teacher gpt-4.1-mini \
        --prompts-file my_prompts.txt \
        --output-dir ./my_dataset

    # Add --seed N for reproducible prompt sampling and train/val/test shuffling.
"""

import json
import os
import random
import re
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# Make the sibling `common` module importable regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients, _clamp_score

import openai

# Rubric dimensions requested by QUALITY_PROMPT; only these keys are kept and averaged.
_SCORE_KEYS = ("accuracy", "quality", "task_fit")
# First flat {...} object in the judge reply (the requested rubric JSON has no nesting).
_JSON_OBJECT_RE = re.compile(r"\{[^}]+\}")
# Emit a progress line every this many items.
_PROGRESS_EVERY = 25


def verify_deployment(client, model):
    """Verify a model deployment exists by sending a trivial request.

    Returns False only when the service raises ``openai.NotFoundError``.
    Any other exception (rate limit, auth, bad request, ...) is treated as
    "the deployment exists" so that transient problems do not abort the run.
    """
    try:
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Hi"}],
            max_completion_tokens=1,
        )
        return True
    except openai.NotFoundError:
        return False
    except Exception:
        return True  # other errors (rate limit, etc.) mean the deployment exists


def generate_combinatorial_prompts(topics, formats, contexts, n):
    """Generate ``n`` prompts by sampling one topic, format and context each.

    An empty context string produces a prompt without the ``Context:``
    preamble, so the default (no contexts) does not inject an empty
    "Context: " header into every training example.
    """
    prompts = []
    for _ in range(n):
        topic = random.choice(topics)
        fmt = random.choice(formats)
        context = random.choice(contexts)
        task = f"Write {fmt} about: {topic}."
        prompts.append(f"Context: {context}\n\n{task}" if context else task)
    return prompts


def teacher_generate(client, model, system_prompt, prompt, retries=3):
    """Generate a single response from the teacher.

    Returns the response text, or None when every attempt fails, the reply is
    empty, or the reply was cut off by ``max_completion_tokens``
    (``finish_reason == "length"``). Truncated answers are dropped because
    they would teach the student to stop mid-answer.
    Sleeps 2s, 4s, ... between failed attempts.
    """
    for attempt in range(retries):
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                max_completion_tokens=1024,
            )
            choice = resp.choices[0]
            if getattr(choice, "finish_reason", None) == "length":
                print(" Skipping truncated teacher response (hit max_completion_tokens)")
                return None
            return choice.message.content
        except Exception as e:
            if attempt >= retries - 1:
                print(f" Failed after {retries} attempts: {e}")
                return None
            time.sleep(2 * (attempt + 1))
    return None


QUALITY_PROMPT = """Rate this AI-generated text on quality dimensions (1-10 each).

## Text to evaluate
{output}

## Dimensions
**Accuracy** (1-10): Is the content factually sound and coherent?
**Quality** (1-10): Is it well-written, clear, and professional?
**Task-fit** (1-10): Does it match the requested format and purpose?

Return ONLY JSON: {{"accuracy": <int>, "quality": <int>, "task_fit": <int>}}"""


def grade_output(client, judge_model, output, retries=3):
    """Score ``output`` with the judge model on the QUALITY_PROMPT rubric.

    Returns a dict with exactly the ``_SCORE_KEYS`` dimensions, each passed
    through ``common._clamp_score`` (presumably clamping to the 1-10 scale;
    defined in common). Extra keys from the judge are ignored so they cannot
    skew the caller's average. Returns None when no attempt yields a complete
    score set.
    """
    last_error = None
    for attempt in range(retries):
        try:
            resp = client.chat.completions.create(
                model=judge_model,
                messages=[{"role": "user", "content": QUALITY_PROMPT.format(output=output)}],
                temperature=0.0,
                max_completion_tokens=100,
            )
            text = (resp.choices[0].message.content or "").strip()
            match = _JSON_OBJECT_RE.search(text)
            if match:
                raw = json.loads(match.group())
                if all(key in raw for key in _SCORE_KEYS):
                    return {key: _clamp_score(raw[key]) for key in _SCORE_KEYS}
            last_error = "judge reply did not contain a complete score JSON object"
        except Exception as e:  # API error, malformed JSON, or unscoreable value
            last_error = e
        if attempt < retries - 1:
            time.sleep(2)
    if last_error is not None:
        print(f" Grading failed after {retries} attempts: {last_error}")
    return None


def _build_parser():
    """Create the command-line parser (all options are unchanged; --seed is new and optional)."""
    parser = HelpOnErrorParser(description="Generate distillation training data from a teacher model")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"))
    parser.add_argument("--teacher", required=True, help="Teacher model deployment name")
    parser.add_argument("--judge", default=None, help="Judge model (default: same as teacher)")
    parser.add_argument("--system-prompt", default="You are a helpful assistant.", help="System prompt for teacher")

    # Prompt generation (either combinatorial or from file)
    parser.add_argument("--prompts-file", help="File with one prompt per line (skips combinatorial generation)")
    parser.add_argument("--topics", help="Comma-separated topics for combinatorial prompts")
    parser.add_argument("--formats", default="a concise response,a brief summary,a detailed explanation",
                        help="Comma-separated output formats")
    parser.add_argument("--contexts", default="", help="Comma-separated context sentences")
    parser.add_argument("--num-prompts", type=int, default=300, help="Number of prompts to generate")

    # Quality
    parser.add_argument("--min-score", type=float, default=7.0, help="Minimum average quality score to keep")
    parser.add_argument("--skip-grading", action="store_true", help="Skip quality grading (keep all)")

    # Output
    parser.add_argument("--output-dir", default="./distillation_data", help="Output directory")
    parser.add_argument("--train-split", type=float, default=0.8)
    parser.add_argument("--val-split", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for prompt sampling and split shuffling (default: unseeded)")
    return parser


def _split_csv(value):
    """Split a comma-separated string into stripped, non-empty items (None -> [])."""
    return [item.strip() for item in (value or "").split(",") if item.strip()]


def _load_prompts(args):
    """Read prompts from --prompts-file, or sample them from the combinatorial axes."""
    if args.prompts_file:
        try:
            with open(args.prompts_file, encoding="utf-8") as pf:
                prompts = [line.strip() for line in pf if line.strip()]
        except OSError as e:
            print(f"ERROR: Could not read prompts file '{args.prompts_file}': {e}")
            sys.exit(1)
        print(f"Loaded {len(prompts)} prompts from {args.prompts_file}")
        return prompts

    topics = _split_csv(args.topics) or ["general knowledge"]
    formats = _split_csv(args.formats)
    if not formats:
        print("ERROR: --formats must contain at least one non-empty format.")
        sys.exit(1)
    contexts = _split_csv(args.contexts) or [""]  # [""] = no context preamble
    prompts = generate_combinatorial_prompts(topics, formats, contexts, args.num_prompts)
    print(f"Generated {len(prompts)} prompts ({len(topics)} topics × {len(formats)} formats × {len(contexts)} contexts)")
    return prompts


def _generate_examples(client, args, prompts):
    """Ask the teacher to answer every prompt; prompts without a usable answer are dropped."""
    print(f"\nTeacher ({args.teacher}) generating responses...")
    examples = []
    for i, prompt in enumerate(prompts, start=1):
        response = teacher_generate(client, args.teacher, args.system_prompt, prompt)
        if response:
            examples.append({"prompt": prompt, "response": response})
        if i % _PROGRESS_EVERY == 0:
            print(f" {i}/{len(prompts)} ({len(examples)} successful)")
    print(f" Teacher produced {len(examples)}/{len(prompts)} responses")
    return examples


def _grade_and_filter(client, judge, examples, min_score):
    """Attach judge scores to each example and return those whose average is >= min_score."""
    print(f"\nGrading with {judge}...")
    for i, ex in enumerate(examples, start=1):
        scores = grade_output(client, judge, ex["response"])
        if scores:
            ex["scores"] = scores
            ex["avg_score"] = sum(scores.values()) / len(scores)
        else:
            # Ungradeable responses get 0, so they fail any positive --min-score.
            ex["avg_score"] = 0
        if i % _PROGRESS_EVERY == 0:
            print(f" Graded {i}/{len(examples)}")

    filtered = [ex for ex in examples if ex["avg_score"] >= min_score]
    avgs = [ex["avg_score"] for ex in examples if ex["avg_score"] > 0]
    print(f" Passed filter (>= {min_score}): {len(filtered)}/{len(examples)}")
    if avgs:
        print(f" Scores: min={min(avgs):.1f}, max={max(avgs):.1f}, mean={sum(avgs)/len(avgs):.1f}")
    return filtered


def _write_splits(sft_data, output_dir, train_split, val_split):
    """Shuffle the SFT records in place and write train/validation/test JSONL files."""
    random.shuffle(sft_data)
    n = len(sft_data)
    t_end = int(n * train_split)
    v_end = int(n * (train_split + val_split))
    splits = {"train": sft_data[:t_end], "validation": sft_data[t_end:v_end], "test": sft_data[v_end:]}

    os.makedirs(output_dir, exist_ok=True)
    for name, data in splits.items():
        path = os.path.join(output_dir, f"{name}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in data)
        print(f" {name}: {len(data)} examples → {path}")


def main():
    """CLI entry point: verify deployments, build prompts, distill, grade, and write splits."""
    parser = _build_parser()
    args = parser.parse_args()

    # Reject fractions that would silently produce nonsensical splits
    # (small tolerance so float sums like 0.7 + 0.3 are accepted).
    if args.train_split < 0 or args.val_split < 0 or args.train_split + args.val_split > 1 + 1e-9:
        print("ERROR: --train-split and --val-split must be non-negative and sum to at most 1.")
        sys.exit(1)
    if args.seed is not None:
        random.seed(args.seed)

    client, _ = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key
    )
    judge = args.judge or args.teacher

    # Step 0: Verify deployments exist
    print(f"Verifying deployment '{args.teacher}'...")
    if not verify_deployment(client, args.teacher):
        print(f" ERROR: Deployment '{args.teacher}' not found. Available deployments can be listed in Azure Portal.")
        sys.exit(1)
    print(" ✅ Teacher deployment verified.")

    if judge != args.teacher:
        print(f"Verifying judge deployment '{judge}'...")
        if not verify_deployment(client, judge):
            print(f" ERROR: Judge deployment '{judge}' not found.")
            sys.exit(1)
        print(" ✅ Judge deployment verified.")

    # Step 1: Generate or load prompts
    prompts = _load_prompts(args)
    if not prompts:
        print("ERROR: No prompts to process.")
        sys.exit(1)

    # Step 2: Teacher generates responses
    examples = _generate_examples(client, args, prompts)

    # Step 3: Quality grade and filter
    if not args.skip_grading:
        filtered = _grade_and_filter(client, judge, examples, args.min_score)
    else:
        filtered = examples
        print(f"Skipping grading — keeping all {len(filtered)} examples")

    # Do not overwrite an existing dataset with empty files.
    if not filtered:
        print("ERROR: No examples survived generation/filtering; nothing written. "
              "Check the teacher output or lower --min-score.")
        sys.exit(1)

    # Step 4: Convert to SFT format and split
    sft_data = [{"messages": [
        {"role": "system", "content": args.system_prompt},
        {"role": "user", "content": ex["prompt"]},
        {"role": "assistant", "content": ex["response"]},
    ]} for ex in filtered]
    _write_splits(sft_data, args.output_dir, args.train_split, args.val_split)

    print(f"\n✅ Done! Dataset ready in {args.output_dir}/")


if __name__ == "__main__":
    main()