Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Deploy, evaluate, and manage AI agents end-to-end on Microsoft Azure AI Foundry
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/check_training.py
# /// script
# dependencies = [
#     "openai>=1.0",
#     "azure-identity",
#     "azure-ai-projects",
# ]
# ///
"""
check_training.py — Analyze training curves, detect overfitting, list checkpoints.

Usage:
    python check_training.py --job-id ftjob-abc123
    python check_training.py --job-id ftjob-abc123 --download-csv results.csv
    python check_training.py --base-url https://<resource>.services.ai.azure.com/api/projects/<project>/openai/v1/ --api-key KEY --job-id ftjob-abc123
"""

import csv
import io
import math
import os
import sys

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine
# Make the sibling `common` module importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients  # noqa: E402

# Result-CSV columns that may hold the validation loss, in order of preference.
_VAL_LOSS_COLUMNS = ("valid_loss", "full_valid_loss", "eval_loss")
# Final val_loss more than this factor above the best val_loss counts as overfitting.
_OVERFIT_VAL_FACTOR = 1.2
# Final val/train loss ratio above this counts as moderate overfitting.
_MODERATE_VAL_TRAIN_RATIO = 1.5


def _parse_float(cell):
    """Return *cell* as a float, or None when it is missing, blank, unparseable or NaN."""
    if cell is None:
        # csv.DictReader fills short rows with None rather than "".
        return None
    text = str(cell).strip()
    if not text:
        return None
    try:
        value = float(text)
    except ValueError:
        # A malformed cell is treated like a missing one instead of aborting the report.
        return None
    # NaN compares false against everything and would make min() unreliable.
    return None if math.isnan(value) else value


def _parse_step(cell):
    """Return the step number in *cell* as an int, defaulting to 0 when unusable."""
    value = _parse_float(cell)
    if value is None or not math.isfinite(value):
        return 0
    return int(value)


def _extract_val_points(rows):
    """Return (step, val_loss, train_loss) for every row that carries a validation loss."""
    points = []
    for row in rows:
        candidates = (_parse_float(row.get(col)) for col in _VAL_LOSS_COLUMNS)
        val_loss = next((v for v in candidates if v is not None), None)
        if val_loss is None:
            continue
        points.append((_parse_step(row.get("step")), val_loss, _parse_float(row.get("train_loss"))))
    return points


def _is_overfit(best_val, final_val):
    """Return True when the final val_loss is well above the best one seen."""
    return best_val > 0 and final_val > best_val * _OVERFIT_VAL_FACTOR


def _pick_checkpoint(checkpoints, best_step):
    """Choose the checkpoint to recommend from (step, val_loss, model_id) tuples.

    Prefers the lowest reported val_loss. When no checkpoint reports one, falls
    back to the latest checkpoint at or before *best_step*, then to the first
    checkpoint. Returns None when *checkpoints* is empty.
    """
    scored = [cp for cp in checkpoints if cp[1] is not None]
    if scored:
        return min(scored, key=lambda cp: cp[1])
    not_later = [cp for cp in checkpoints if cp[0] <= best_step]
    if not_later:
        return max(not_later, key=lambda cp: cp[0])
    return checkpoints[0] if checkpoints else None


def _list_checkpoints(client, job_id):
    """Print the job's checkpoints and return them as (step, val_loss, model_id) tuples."""
    print("\n Checkpoints:")
    found = []
    try:
        page = client.fine_tuning.jobs.checkpoints.list(job_id)
        if page.data:
            for cp in sorted(page.data, key=lambda c: c.step_number):
                vl = cp.metrics.valid_loss if cp.metrics and cp.metrics.valid_loss is not None else None
                model_id = cp.fine_tuned_model_checkpoint or "N/A"
                vl_str = f"{vl:.4f}" if vl is not None else "N/A"
                found.append((cp.step_number, vl, model_id))
                print(f" Step {cp.step_number}: val_loss={vl_str}, model={model_id}")
        else:
            print(" No checkpoints available.")
    except Exception as e:
        # Best effort: a checkpoint-listing failure must not discard the curve analysis.
        print(f" Could not retrieve checkpoints: {e}")
    return found


def analyze_job(client, job_id, download_csv=None):
    """Pull training results, analyze curves, detect overfitting.

    Args:
        client: Client exposing ``fine_tuning.jobs`` and ``files`` as used below.
        job_id: ID of the fine-tuning job to inspect.
        download_csv: Optional path; when given, the raw results CSV is saved there.

    Prints a report to stdout and returns None. Returns early when the job is
    neither succeeded nor running, has no result files, or has no validation
    loss data.
    """
    job = client.fine_tuning.jobs.retrieve(job_id)

    print(f"Job: {job.id}")
    print(f" Model: {job.model}")
    print(f" Status: {job.status}")
    print(f" Fine-tuned model: {job.fine_tuned_model}")

    if job.hyperparameters:
        hp = job.hyperparameters
        print(f" Epochs: {getattr(hp, 'n_epochs', 'N/A')}")
        print(f" LR multiplier: {getattr(hp, 'learning_rate_multiplier', 'N/A')}")
        print(f" Batch size: {getattr(hp, 'batch_size', 'N/A')}")

    # Allow analysis while still running if result files exist
    if job.status not in ("succeeded", "running"):
        print(f"\n Job status is '{job.status}'. Cannot analyze curves.")
        return

    if not job.result_files:
        if job.status == "running":
            print("\n Job is still running and no result files available yet. Check back later.")
        else:
            print("\n No result files available.")
        return

    # Download results CSV
    content = client.files.content(job.result_files[0])
    csv_data = content.read()

    if download_csv:
        with open(download_csv, "wb") as f:
            f.write(csv_data)
        print(f"\n Results CSV saved to {download_csv}")

    # Parse CSV. "utf-8-sig" drops a leading BOM, which would otherwise corrupt
    # the first header name; BOM-less input decodes exactly as with "utf-8".
    rows = list(csv.DictReader(io.StringIO(csv_data.decode("utf-8-sig"))))

    if job.status == "running":
        print(f"\n ⚡ Job still running — showing partial results ({len(rows)} steps so far)")

    val_points = _extract_val_points(rows)
    if not val_points:
        print("\n No validation loss data found in results CSV.")
        return

    # Find best validation checkpoint
    best_step, best_val, _best_train = min(val_points, key=lambda p: p[1])
    final_step, final_val, final_train = val_points[-1]

    print("\n Training Curve Analysis:")
    print(f" {'Step':>6} {'Val Loss':>10} {'Train Loss':>12} {'Ratio':>8}")
    print(f" {'─'*6} {'─'*10} {'─'*12} {'─'*8}")
    for step, val, train in val_points:
        has_train = train is not None and train > 0
        # Show N/A rather than a misleading 0.00 when there is no usable train loss.
        ratio_str = f"{val / train:>8.2f}" if has_train else f"{'N/A':>8}"
        train_str = f"{train:12.4f}" if train is not None else f"{'N/A':>12}"
        marker = " ← best" if step == best_step else ""
        print(f" {step:>6} {val:>10.4f} {train_str} {ratio_str}{marker}")

    print(f"\n Best val_loss: {best_val:.4f} at step {best_step}")
    print(f" Final val_loss: {final_val:.4f} at step {final_step}")

    # Overfitting detection
    overfit = _is_overfit(best_val, final_val)
    if overfit:
        pct = (final_val - best_val) / best_val * 100
        print(f"\n ⚠️ OVERFITTING DETECTED: Final val_loss is {pct:.0f}% above best.")
    elif best_val == 0 and final_val > 0:
        print(f"\n ⚠️ Best val_loss was 0.0; final val_loss is {final_val:.4f} — possible overfitting from a near-perfect early checkpoint.")
    elif final_train and final_val / final_train > _MODERATE_VAL_TRAIN_RATIO:
        ratio = final_val / final_train
        print(f"\n ⚠️ MODERATE OVERFITTING: val/train ratio = {ratio:.2f}")
    else:
        print("\n ✅ Training looks healthy. No significant overfitting detected.")

    # List checkpoints and recommend best deployable one
    available_checkpoints = _list_checkpoints(client, job_id)

    if not overfit:
        return

    best_cp = _pick_checkpoint(available_checkpoints, best_step)
    if best_cp is not None:
        cp_step, cp_vl, cp_model = best_cp
        vl_info = f" (val_loss={cp_vl:.4f})" if cp_vl is not None else ""
        print(f"\n 🎯 Recommended checkpoint: step {cp_step}{vl_info}")
        print(f" Model ID: {cp_model}")
        print(f" (Best val_loss was at step {best_step}, nearest deployable checkpoint is step {cp_step})")
        print(" Alternatively, retrain with fewer epochs to avoid overfitting.")
    else:
        # Overfitting was detected but there is no checkpoint to fall back to.
        print(f"\n Recommendation: Retrain with fewer epochs (best val_loss was at step {best_step}).")


def main():
    """Parse command-line arguments, build the client and run the analysis."""
    parser = HelpOnErrorParser(description="Analyze fine-tuning training curves")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"))
    parser.add_argument("--job-id", required=True, help="Fine-tuning job ID")
    parser.add_argument("--download-csv", help="Save results CSV to this path")
    args = parser.parse_args()

    # get_clients returns a second value that this script does not use.
    client, _method = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key,
    )
    analyze_job(client, args.job_id, args.download_csv)


if __name__ == "__main__":
    main()