Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Deploy, evaluate, and manage AI agents end-to-end on Microsoft Azure AI Foundry
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/submit_training.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "requests",
#   "azure-identity",
#   "azure-ai-projects",
# ]
# ///
"""
submit_training.py — Submit SFT, DPO, or RFT training jobs on Azure AI Foundry.

Handles both SDK and REST API submission (REST fallback for OSS models).
Supports /v1/ project endpoint (preferred) and Azure endpoint (fallback).

Usage:
    python submit_training.py --base-url https://<resource>.services.ai.azure.com/api/projects/<project>/openai/v1/ \
        --api-key KEY --training-file training.jsonl --validation-file validation.jsonl \
        --model gpt-4.1-mini --type sft --epochs 2 --lr 1.0

    python submit_training.py --endpoint https://<resource>.openai.azure.com --api-key KEY \
        --training-file-id file-abc123 --validation-file-id file-def456 \
        --model gpt-oss-20b --type sft --epochs 2 --lr 0.5 --use-rest

    python submit_training.py --base-url <url> --api-key KEY \
        --training-file-id file-abc123 --validation-file-id file-def456 \
        --model o4-mini-2025-04-16 --type rft --grader-file grader.py

Argument errors (missing data files, missing grader, missing REST credentials)
are reported before anything is uploaded, so a bad invocation leaves no
orphaned uploaded files behind.
"""

import json
import os
import sys


try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine
# Make the sibling `common` module importable regardless of the caller's CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients, upload_file  # noqa: E402

import requests  # noqa: E402


# api-version used by the REST fallback (.../openai/fine_tuning/jobs).
REST_API_VERSION = "2025-04-01-preview"

# Training tier used for SFT when --training-type is not given
# (the only tier the OSS models listed in the CLI help support).
DEFAULT_SFT_TRAINING_TYPE = "globalStandard"

# Lower-cased substrings of an SDK error that signal "submit via REST instead".
# Matched loosely rather than on one exact message because Azure changes its
# error text periodically.
_REST_FALLBACK_SIGNALS = (
    "trainingtype",
    "globalstandard",
    "global_standard",
    "does not support fine-tuning",
)


def submit_sft_sdk(client, model, train_id, val_id, epochs=2, lr=1.0, batch_size=None,
                   suffix=None, training_type=DEFAULT_SFT_TRAINING_TYPE):
    """Submit a supervised fine-tuning (SFT) job using the Python SDK.

    Args:
        client: Client returned by ``common.get_clients``; must expose
            ``fine_tuning.jobs.create``.
        model: Base model name.
        train_id: ID of the already-uploaded training file.
        val_id: ID of the already-uploaded validation file.
        epochs: ``n_epochs`` hyperparameter.
        lr: ``learning_rate_multiplier`` hyperparameter.
        batch_size: Optional ``batch_size``; omitted when falsy.
        suffix: Optional model suffix for identification; omitted when falsy.
        training_type: Azure training tier, sent as ``trainingType``.

    Returns:
        dict with the job ``id``, ``status``, ``model`` and ``method`` ("sdk").
    """
    hyperparameters = {"n_epochs": epochs, "learning_rate_multiplier": lr}
    if batch_size:
        hyperparameters["batch_size"] = batch_size

    kwargs = dict(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={"type": "supervised"},
        hyperparameters=hyperparameters,
        # Azure-specific: passed via extra_body since the OpenAI SDK has no
        # top-level trainingType kwarg.
        extra_body={"trainingType": training_type},
    )
    if suffix:
        kwargs["suffix"] = suffix

    job = client.fine_tuning.jobs.create(**kwargs)
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk"}


def _rest_error_message(resp):
    """Return a best-effort human-readable error message for a failed REST response.

    The expected body shape is ``{"error": {"message": ...}}``. A bare-string
    ``error`` is also accepted. Anything else falls back to the first 200
    characters of the raw body.
    """
    try:
        payload = resp.json()
    except ValueError:
        payload = None
    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict) and error.get("message"):
            return str(error["message"])
        if isinstance(error, str) and error:
            return error
    return resp.text[:200] if resp.text else "Unknown error"


def submit_sft_rest(endpoint, api_key, model, train_id, val_id, epochs=2, lr=1.0, batch_size=None,
                    suffix=None, training_type=DEFAULT_SFT_TRAINING_TYPE, *,
                    api_version=REST_API_VERSION):
    """Submit an SFT job via the REST API (fallback for models like gpt-oss-20b).

    Args:
        endpoint: Azure OpenAI resource endpoint; a trailing slash is tolerated.
        api_key: Key sent in the ``api-key`` header (REST path is key-auth only).
        model, train_id, val_id, epochs, lr, batch_size, suffix, training_type:
            Same meaning as in :func:`submit_sft_sdk`.
        api_version: ``api-version`` query parameter for the request.

    Returns:
        dict with the job ``id``, ``status``, ``model`` and ``method`` ("rest").

    Raises:
        RuntimeError: on a non-200/201 status, a non-JSON success body, or a
            success body lacking ``id``/``status``.
    """
    # rstrip avoids a "//openai/..." path when the endpoint ends with "/".
    url = f"{endpoint.rstrip('/')}/openai/fine_tuning/jobs?api-version={api_version}"
    body = {
        "model": model,
        "training_file": train_id,
        "validation_file": val_id,
        "method": {"type": "supervised"},
        "hyperparameters": {"n_epochs": epochs, "learning_rate_multiplier": lr},
        "trainingType": training_type,
    }
    if batch_size:
        body["hyperparameters"]["batch_size"] = batch_size
    if suffix:
        body["suffix"] = suffix

    resp = requests.post(url, headers={
        "Content-Type": "application/json",
        "api-key": api_key,
    }, json=body, timeout=(10, 60))

    if resp.status_code not in (200, 201):
        raise RuntimeError(
            f"REST submission failed ({resp.status_code}): {_rest_error_message(resp)}"
        )

    try:
        data = resp.json()
    except ValueError as err:
        raise RuntimeError(
            f"REST submission returned {resp.status_code} but body was not JSON: {resp.text[:200]}"
        ) from err
    if not isinstance(data, dict) or "id" not in data or "status" not in data:
        raise RuntimeError(f"REST response missing 'id' or 'status' fields: {data}")
    return {"id": data["id"], "status": data["status"], "model": model, "method": "rest"}


def submit_rft(client, model, train_id, val_id, grader_source, suffix=None, training_type=None):
    """Submit a reinforcement fine-tuning (RFT) job with a Python grader.

    Args:
        client: Client returned by ``common.get_clients``.
        model: Base model name.
        train_id: ID of the already-uploaded training file.
        val_id: ID of the already-uploaded validation file.
        grader_source: Python source of the grader, sent verbatim.
        suffix: Optional model suffix; omitted when falsy.
        training_type: Optional Azure training tier; ``trainingType`` is only
            sent when this is given.

    Returns:
        dict with the job ``id``, ``status``, ``model`` and ``method`` ("sdk-rft").
    """
    kwargs = dict(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={
            "type": "reinforcement",
            "reinforcement": {
                "grader": {
                    "type": "python",
                    "name": "custom_grader",
                    "source": grader_source,
                },
            },
        },
    )
    if suffix:
        kwargs["suffix"] = suffix
    if training_type:
        kwargs["extra_body"] = {"trainingType": training_type}

    job = client.fine_tuning.jobs.create(**kwargs)
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk-rft"}


def submit_dpo(client, model, train_id, val_id, epochs=2, lr=1.0, beta=0.1, suffix=None,
               training_type=None):
    """Submit a direct preference optimization (DPO) job.

    Args:
        client: Client returned by ``common.get_clients``.
        model: Base model name.
        train_id: ID of the already-uploaded training file.
        val_id: ID of the already-uploaded validation file.
        epochs: ``n_epochs`` hyperparameter.
        lr: ``learning_rate_multiplier`` hyperparameter.
        beta: DPO ``beta`` (alignment strength).
        suffix: Optional model suffix; omitted when falsy.
        training_type: Optional Azure training tier; ``trainingType`` is only
            sent when this is given.

    Returns:
        dict with the job ``id``, ``status``, ``model`` and ``method`` ("sdk-dpo").
    """
    kwargs = dict(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={
            "type": "dpo",
            "dpo": {
                "hyperparameters": {
                    "n_epochs": epochs,
                    "beta": beta,
                    "learning_rate_multiplier": lr,
                },
            },
        },
    )
    # Omit rather than send suffix=None, matching submit_sft_sdk.
    if suffix:
        kwargs["suffix"] = suffix
    if training_type:
        kwargs["extra_body"] = {"trainingType": training_type}

    job = client.fine_tuning.jobs.create(**kwargs)
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk-dpo"}


def _build_parser():
    """Build the CLI parser; connection options default to environment variables."""
    parser = HelpOnErrorParser(description="Submit fine-tuning jobs on Azure AI Foundry")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"),
                        help="API key")
    parser.add_argument("--model", required=True, help="Base model name (e.g., gpt-4.1-mini)")
    parser.add_argument("--type", choices=["sft", "dpo", "rft"], default="sft",
                        help="Training type: sft, dpo, or rft")

    # Data files — either paths (will upload) or IDs (already uploaded)
    parser.add_argument("--training-file", help="Path to training JSONL file (will upload)")
    parser.add_argument("--validation-file", help="Path to validation JSONL file (will upload)")
    parser.add_argument("--training-file-id", help="Already-uploaded training file ID")
    parser.add_argument("--validation-file-id", help="Already-uploaded validation file ID")

    # Hyperparameters
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1.0, help="Learning rate multiplier")
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--suffix", help="Model suffix for identification")

    # DPO-specific
    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (alignment strength)")

    # RFT-specific
    parser.add_argument("--grader-file", help="Path to Python grader file (for RFT)")

    # REST fallback
    parser.add_argument("--use-rest", action="store_true",
                        help="Force REST API for SFT (needed for gpt-oss-20b and other OSS models)")
    parser.add_argument("--training-type", choices=["globalStandard", "developerTier", "standard"],
                        default=None,
                        help="Azure training tier. SFT defaults to globalStandard; DPO/RFT only "
                             "send it when given. developerTier is ~50%% off "
                             "globalStandard with lower quotas. OSS models (gpt-oss-20b, Ministral, "
                             "Llama, Qwen) only support globalStandard.")
    return parser


def _validate_args(args):
    """Reject invalid argument combinations before any file is uploaded.

    Exits with status 1 (message on stderr) on error.

    Returns:
        The grader source text for RFT jobs, otherwise None.
    """
    has_train = args.training_file or args.training_file_id
    has_val = args.validation_file or args.validation_file_id
    if not has_train or not has_val:
        sys.exit("Error: Provide training and validation file paths or IDs")

    grader_source = None
    if args.type == "rft":
        if not args.grader_file:
            sys.exit("Error: --grader-file required for RFT")
        try:
            with open(args.grader_file, encoding="utf-8") as f:
                grader_source = f.read()
        except OSError as err:
            sys.exit(f"Error: cannot read --grader-file: {err}")

    if args.use_rest:
        if args.type != "sft":
            print(f"Warning: --use-rest only applies to SFT; submitting {args.type} via SDK.",
                  file=sys.stderr)
        elif not args.endpoint or not args.api_key:
            sys.exit("Error: --use-rest requires --endpoint and --api-key "
                     "(REST does not support DefaultAzureCredential)")
    return grader_source


def _wants_rest_fallback(exc):
    """Return True if an SDK error message signals the job should be sent via REST."""
    message = str(exc).lower()
    return any(signal in message for signal in _REST_FALLBACK_SIGNALS)


def _submit_sft(client, args, train_id, val_id):
    """Submit an SFT job via SDK, or REST when forced or when the SDK signals it."""
    sft_args = (args.model, train_id, val_id, args.epochs, args.lr, args.batch_size,
                args.suffix, args.training_type or DEFAULT_SFT_TRAINING_TYPE)
    if args.use_rest:
        # Credentials were already checked by _validate_args.
        return submit_sft_rest(args.endpoint, args.api_key, *sft_args)

    try:
        return submit_sft_sdk(client, *sft_args)
    except Exception as e:
        if not _wants_rest_fallback(e):
            raise
        if not args.endpoint or not args.api_key:
            sys.exit(f"SDK failed for {args.model}: {e}\n"
                     "REST fallback requires --endpoint and --api-key.")
        print(f"SDK failed for {args.model}, falling back to REST API...")
        return submit_sft_rest(args.endpoint, args.api_key, *sft_args)


def main():
    """CLI entry point: validate args, upload data if needed, submit, save job info."""
    args = _build_parser().parse_args()
    grader_source = _validate_args(args)

    client, _method = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key,
    )

    # A local path (uploaded here) takes precedence over an already-uploaded ID.
    train_id = (upload_file(client, args.training_file) if args.training_file
                else args.training_file_id)
    val_id = (upload_file(client, args.validation_file) if args.validation_file
              else args.validation_file_id)

    if args.type == "rft":
        result = submit_rft(client, args.model, train_id, val_id, grader_source,
                            suffix=args.suffix, training_type=args.training_type)
    elif args.type == "dpo":
        result = submit_dpo(client, args.model, train_id, val_id,
                            args.epochs, args.lr, args.beta, args.suffix,
                            training_type=args.training_type)
    else:
        result = _submit_sft(client, args, train_id, val_id)

    print("\nJob submitted successfully:")
    print(json.dumps(result, indent=2))

    # Save job info
    outfile = f"ft_job_{result['id']}.json"
    with open(outfile, "w", encoding="utf-8") as f:
        json.dump({**result, "epochs": args.epochs, "lr": args.lr,
                   "batch_size": args.batch_size, "train_file": train_id,
                   "val_file": val_id}, f, indent=2)
    print(f"Job info saved to {outfile}")


if __name__ == "__main__":
    main()