Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Build and deploy AI applications on Azure AI Foundry using Microsoft's model catalog and AI services
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/check_training.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "azure-identity",
#   "azure-ai-projects",
# ]
# ///
"""
check_training.py — Analyze training curves, detect overfitting, list checkpoints.

Usage:
    python check_training.py --job-id ftjob-abc123
    python check_training.py --job-id ftjob-abc123 --download-csv results.csv
    python check_training.py --base-url https://<resource>.services.ai.azure.com/api/projects/<project>/openai/v1/ --api-key KEY --job-id ftjob-abc123
"""

import csv
import io
import os
import sys

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine
# Make the sibling "common" module importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients

# Result-CSV columns that may carry the validation loss, in lookup order.
_VAL_LOSS_COLUMNS = ("valid_loss", "full_valid_loss", "eval_loss")

# Final val_loss above best * this factor is reported as overfitting.
_OVERFIT_FACTOR = 1.2

# Final val/train ratio above this is reported as moderate overfitting.
_MODERATE_RATIO = 1.5


def _cell(row, column):
    """Return the stripped text in *column* of a CSV row, or "" if absent."""
    value = row.get(column)
    # csv.DictReader fills the missing cells of a short row with None, so a
    # plain ``row.get(column, "").strip()`` can raise AttributeError.
    return value.strip() if isinstance(value, str) else ""


def _float_or_none(text):
    """Parse *text* as a float; return None when it is empty or not numeric."""
    if not text:
        return None
    try:
        return float(text)
    except ValueError:
        return None


def _parse_step(text):
    """Parse a step cell written as "12" or "12.0"; return 0 if unparseable."""
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return int(float(text))
    except (ValueError, OverflowError):
        # Empty, non-numeric, NaN or infinite step: use the same default the
        # script applies when the column is missing altogether.
        return 0


def _extract_val_points(rows):
    """Return ``(step, val_loss, train_loss)`` for each row that has a validation loss.

    ``train_loss`` is None when the row has no parseable training loss.
    Rows without a parseable value in any of ``_VAL_LOSS_COLUMNS`` are skipped.
    """
    points = []
    for row in rows:
        val_loss = None
        for col in _VAL_LOSS_COLUMNS:
            val_loss = _float_or_none(_cell(row, col))
            if val_loss is not None:
                break
        if val_loss is None:
            continue
        step = _parse_step(_cell(row, "step"))
        train_loss = _float_or_none(_cell(row, "train_loss"))
        points.append((step, val_loss, train_loss))
    return points


def _pick_checkpoint(available_checkpoints, best_step):
    """Choose the checkpoint to recommend from a non-empty list.

    Each entry is ``(step, val_loss_or_None, model_id)``. Preference order:
    lowest reported val_loss; otherwise the latest checkpoint at or before
    *best_step*; otherwise the first entry in the list.
    """
    scored_cps = [cp for cp in available_checkpoints if cp[1] is not None]
    if scored_cps:
        return min(scored_cps, key=lambda cp: cp[1])
    earlier_cps = [cp for cp in available_checkpoints if cp[0] <= best_step]
    if earlier_cps:
        return max(earlier_cps, key=lambda cp: cp[0])
    return available_checkpoints[0]


def analyze_job(client, job_id, download_csv=None):
    """Pull training results, analyze curves, detect overfitting.

    Retrieves the job, prints its summary and hyperparameters, downloads the
    first result file, prints a per-step table of validation loss, training
    loss and their ratio, reports an overfitting verdict, lists the job's
    checkpoints and, when overfitting is detected, recommends a checkpoint
    (or retraining when no checkpoint is available).

    Args:
        client: Client object exposing ``fine_tuning.jobs.retrieve``,
            ``files.content`` and ``fine_tuning.jobs.checkpoints.list``.
        job_id: Fine-tuning job ID.
        download_csv: Optional path; when given, the raw results CSV bytes
            are written there.

    Returns:
        None. All output is printed to stdout.
    """
    job = client.fine_tuning.jobs.retrieve(job_id)

    print(f"Job: {job.id}")
    print(f" Model: {job.model}")
    print(f" Status: {job.status}")
    print(f" Fine-tuned model: {job.fine_tuned_model}")

    if job.hyperparameters:
        hp = job.hyperparameters
        print(f" Epochs: {getattr(hp, 'n_epochs', 'N/A')}")
        print(f" LR multiplier: {getattr(hp, 'learning_rate_multiplier', 'N/A')}")
        print(f" Batch size: {getattr(hp, 'batch_size', 'N/A')}")

    # Allow analysis while still running if result files exist
    if job.status not in ("succeeded", "running"):
        print(f"\n Job status is '{job.status}'. Cannot analyze curves.")
        return

    if not job.result_files:
        if job.status == "running":
            print("\n Job is still running and no result files available yet. Check back later.")
        else:
            print("\n No result files available.")
        return

    # Download results CSV
    content = client.files.content(job.result_files[0])
    csv_data = content.read()

    if download_csv:
        with open(download_csv, "wb") as f:
            f.write(csv_data)
        print(f"\n Results CSV saved to {download_csv}")

    # Parse CSV. "utf-8-sig" drops a leading BOM if present, so the first
    # header name is not corrupted; newline="" is what the csv module expects.
    reader = csv.DictReader(io.StringIO(csv_data.decode("utf-8-sig"), newline=""))
    rows = list(reader)

    if job.status == "running":
        print(f"\n ⚡ Job still running — showing partial results ({len(rows)} steps so far)")

    # Extract validation checkpoints
    val_points = _extract_val_points(rows)

    if not val_points:
        print("\n No validation loss data found in results CSV.")
        return

    # Find best validation checkpoint
    best_step, best_val, best_train = min(val_points, key=lambda x: x[1])
    final_step, final_val, final_train = val_points[-1]

    print("\n Training Curve Analysis:")
    print(f" {'Step':>6} {'Val Loss':>10} {'Train Loss':>12} {'Ratio':>8}")
    print(f" {'─'*6} {'─'*10} {'─'*12} {'─'*8}")
    for step, val, train in val_points:
        marker = " ← best" if step == best_step else ""
        # Pad the N/A placeholders to the numeric column widths so the table
        # stays aligned, and never show a made-up 0.00 ratio.
        train_str = f"{train:12.4f}" if train is not None else f"{'N/A':>12}"
        ratio_str = f"{val / train:>8.2f}" if train and train > 0 else f"{'N/A':>8}"
        print(f" {step:>6} {val:>10.4f} {train_str} {ratio_str}{marker}")

    print(f"\n Best val_loss: {best_val:.4f} at step {best_step}")
    print(f" Final val_loss: {final_val:.4f} at step {final_step}")

    # Overfitting detection
    severe_overfit = best_val > 0 and final_val > best_val * _OVERFIT_FACTOR
    if severe_overfit:
        pct = (final_val - best_val) / best_val * 100
        print(f"\n ⚠️ OVERFITTING DETECTED: Final val_loss is {pct:.0f}% above best.")
    elif best_val == 0 and final_val > 0:
        print(f"\n ⚠️ Best val_loss was 0.0; final val_loss is {final_val:.4f} — possible overfitting from a near-perfect early checkpoint.")
    elif final_train and final_val / final_train > _MODERATE_RATIO:
        ratio = final_val / final_train
        print(f"\n ⚠️ MODERATE OVERFITTING: val/train ratio = {ratio:.2f}")
    else:
        print("\n ✅ Training looks healthy. No significant overfitting detected.")

    # List checkpoints and recommend best deployable one
    print("\n Checkpoints:")
    available_checkpoints = []
    try:
        cps = client.fine_tuning.jobs.checkpoints.list(job_id)
        if cps.data:
            for cp in sorted(cps.data, key=lambda c: c.step_number):
                vl = cp.metrics.valid_loss if cp.metrics and cp.metrics.valid_loss is not None else None
                model_id = cp.fine_tuned_model_checkpoint or "N/A"
                vl_str = f"{vl:.4f}" if vl is not None else "N/A"
                available_checkpoints.append((cp.step_number, vl, model_id))
                print(f" Step {cp.step_number}: val_loss={vl_str}, model={model_id}")
        else:
            print(" No checkpoints available.")
    except Exception as e:
        # Best-effort: a checkpoint listing failure must not hide the curve
        # analysis already printed above.
        print(f" Could not retrieve checkpoints: {e}")

    # Recommend the best deployable checkpoint, only when overfitting was detected
    if not severe_overfit:
        return

    if not available_checkpoints:
        print(f"\n Recommendation: Retrain with fewer epochs (best val_loss was at step {best_step}).")
        return

    cp_step, cp_vl, cp_model = _pick_checkpoint(available_checkpoints, best_step)
    vl_info = f" (val_loss={cp_vl:.4f})" if cp_vl is not None else ""
    print(f"\n 🎯 Recommended checkpoint: step {cp_step}{vl_info}")
    print(f" Model ID: {cp_model}")
    print(f" (Best val_loss was at step {best_step}, nearest deployable checkpoint is step {cp_step})")
    print(" Alternatively, retrain with fewer epochs to avoid overfitting.")


def main():
    """Parse command-line options, build the API client, and analyze the job.

    Connection options default to the OPENAI_BASE_URL, AZURE_OPENAI_ENDPOINT,
    AZURE_AI_PROJECT_ENDPOINT and AZURE_OPENAI_API_KEY environment variables.
    """
    parser = HelpOnErrorParser(description="Analyze fine-tuning training curves")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"))
    parser.add_argument("--job-id", required=True, help="Fine-tuning job ID")
    parser.add_argument("--download-csv", help="Save results CSV to this path")
    args = parser.parse_args()

    # get_clients returns a pair; only the client is needed here.
    client, _method = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key
    )
    analyze_job(client, args.job_id, args.download_csv)


if __name__ == "__main__":
    main()