Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Build and deploy AI applications on Azure AI Foundry using Microsoft's model catalog and AI services
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/submit_training.py
# /// script
# dependencies = [
#   "openai>=1.0",
#   "requests",
#   "azure-identity",
#   "azure-ai-projects",
# ]
# ///
"""
submit_training.py — Submit SFT, DPO, or RFT training jobs on Azure AI Foundry.

Handles both SDK and REST API submission (REST fallback for OSS models).
Supports /v1/ project endpoint (preferred) and Azure endpoint (fallback).

Usage:
    python submit_training.py --base-url https://<resource>.services.ai.azure.com/api/projects/<project>/openai/v1/ \
        --api-key KEY --training-file training.jsonl --validation-file validation.jsonl \
        --model gpt-4.1-mini --type sft --epochs 2 --lr 1.0

    python submit_training.py --endpoint https://<resource>.openai.azure.com --api-key KEY \
        --training-file-id file-abc123 --validation-file-id file-def456 \
        --model gpt-oss-20b --type sft --epochs 2 --lr 0.5 --use-rest

    python submit_training.py --base-url <url> --api-key KEY \
        --training-file-id file-abc123 --validation-file-id file-def456 \
        --model o4-mini-2025-04-16 --type rft --grader-file grader.py
"""

import json
import os
import sys


try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine
# The sibling ``common`` module lives next to this script; make it importable
# regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients, upload_file

import requests


# Lower-cased substrings of SDK error text that signal "this model must be
# submitted through the REST API instead". Azure changes the exact wording
# periodically, so several variants are matched.
_REST_FALLBACK_MARKERS = (
    "trainingtype",
    "globalstandard",
    "global_standard",
    "does not support fine-tuning",
)


def submit_sft_sdk(client, model, train_id, val_id, epochs=2, lr=1.0, batch_size=None, suffix=None, training_type="globalStandard"):
    """Submit SFT job using the Python SDK.

    Args:
        client: SDK client exposing ``fine_tuning.jobs.create``.
        model: Base model name.
        train_id: ID of the uploaded training file.
        val_id: ID of the uploaded validation file.
        epochs: Number of training epochs.
        lr: Learning rate multiplier.
        batch_size: Optional batch size; omitted from the request when falsy.
        suffix: Optional model suffix; omitted from the request when falsy.
        training_type: Azure training tier, sent as ``trainingType``.

    Returns:
        Dict with the job ``id`` and ``status``, the ``model`` and
        ``method="sdk"``.
    """
    hp = {"n_epochs": epochs, "learning_rate_multiplier": lr}
    if batch_size:
        hp["batch_size"] = batch_size

    kwargs = dict(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={"type": "supervised"},
        hyperparameters=hp,
        # Azure-specific: passed via extra_body since the OpenAI SDK has no
        # top-level trainingType kwarg.
        extra_body={"trainingType": training_type},
    )
    if suffix:
        kwargs["suffix"] = suffix

    job = client.fine_tuning.jobs.create(**kwargs)
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk"}


def _rest_error_message(resp):
    """Return the best available error text from a failed REST response.

    Uses ``error.message`` from a JSON body when present, and otherwise falls
    back to the first 200 characters of the raw body. Never raises on an
    unexpected body shape (non-JSON, list, ``error`` given as a string).
    """
    try:
        payload = resp.json()
    except ValueError:
        payload = None

    error = payload.get("error") if isinstance(payload, dict) else None
    if isinstance(error, dict) and error.get("message"):
        return str(error["message"])
    if isinstance(error, str) and error:
        return error
    return resp.text[:200] if resp.text else "Unknown error"


def submit_sft_rest(endpoint, api_key, model, train_id, val_id, epochs=2, lr=1.0, batch_size=None, suffix=None, training_type="globalStandard"):
    """Submit SFT job via REST API (fallback for models like gpt-oss-20b).

    Args:
        endpoint: Azure OpenAI resource endpoint; a trailing slash is tolerated.
        api_key: API key sent in the ``api-key`` header.
        model: Base model name.
        train_id: ID of the uploaded training file.
        val_id: ID of the uploaded validation file.
        epochs: Number of training epochs.
        lr: Learning rate multiplier.
        batch_size: Optional batch size; omitted from the request when falsy.
        suffix: Optional model suffix; omitted from the request when falsy.
        training_type: Azure training tier, sent as ``trainingType``.

    Returns:
        Dict with the job ``id`` and ``status``, the ``model`` and
        ``method="rest"``.

    Raises:
        RuntimeError: If the service answers with a non-2xx status, a non-JSON
            body, or a body lacking ``id``/``status``.
    """
    # Strip a trailing slash so "https://x.openai.azure.com/" does not yield
    # a "//openai/..." path.
    base = endpoint.rstrip("/")
    url = f"{base}/openai/fine_tuning/jobs?api-version=2025-04-01-preview"
    body = {
        "model": model,
        "training_file": train_id,
        "validation_file": val_id,
        "method": {"type": "supervised"},
        "hyperparameters": {"n_epochs": epochs, "learning_rate_multiplier": lr},
        "trainingType": training_type,
    }
    if batch_size:
        body["hyperparameters"]["batch_size"] = batch_size
    if suffix:
        body["suffix"] = suffix

    resp = requests.post(
        url,
        headers={
            "Content-Type": "application/json",
            "api-key": api_key,
        },
        json=body,
        timeout=(10, 60),  # (connect, read) seconds
    )

    if resp.status_code not in (200, 201):
        raise RuntimeError(
            f"REST submission failed ({resp.status_code}): {_rest_error_message(resp)}"
        )

    try:
        data = resp.json()
    except ValueError as err:
        raise RuntimeError(
            f"REST submission returned {resp.status_code} but body was not JSON: {resp.text[:200]}"
        ) from err
    # Guard against JSON bodies that are not objects (null, list, string).
    if not isinstance(data, dict) or "id" not in data or "status" not in data:
        raise RuntimeError(f"REST response missing 'id' or 'status' fields: {data}")
    return {"id": data["id"], "status": data["status"], "model": model, "method": "rest"}


def submit_rft(client, model, train_id, val_id, grader_source):
    """Submit RFT job.

    Args:
        client: SDK client exposing ``fine_tuning.jobs.create``.
        model: Base model name.
        train_id: ID of the uploaded training file.
        val_id: ID of the uploaded validation file.
        grader_source: Source code of the Python grader, sent verbatim.

    Returns:
        Dict with the job ``id`` and ``status``, the ``model`` and
        ``method="sdk-rft"``.
    """
    job = client.fine_tuning.jobs.create(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={
            "type": "reinforcement",
            "reinforcement": {
                "grader": {
                    "type": "python",
                    "name": "custom_grader",
                    "source": grader_source,
                },
            },
        },
    )
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk-rft"}


def submit_dpo(client, model, train_id, val_id, epochs=2, lr=1.0, beta=0.1, suffix=None, batch_size=None):
    """Submit DPO job.

    Args:
        client: SDK client exposing ``fine_tuning.jobs.create``.
        model: Base model name.
        train_id: ID of the uploaded training file.
        val_id: ID of the uploaded validation file.
        epochs: Number of training epochs.
        lr: Learning rate multiplier.
        beta: DPO beta value.
        suffix: Optional model suffix; omitted from the request when falsy,
            matching the SFT submission path.
        batch_size: Optional batch size; omitted from the request when falsy.

    Returns:
        Dict with the job ``id`` and ``status``, the ``model`` and
        ``method="sdk-dpo"``.
    """
    hp = {
        "n_epochs": epochs,
        "beta": beta,
        "learning_rate_multiplier": lr,
    }
    if batch_size:
        hp["batch_size"] = batch_size

    kwargs = dict(
        model=model,
        training_file=train_id,
        validation_file=val_id,
        method={"type": "dpo", "dpo": {"hyperparameters": hp}},
    )
    if suffix:
        kwargs["suffix"] = suffix

    job = client.fine_tuning.jobs.create(**kwargs)
    return {"id": job.id, "status": job.status, "model": model, "method": "sdk-dpo"}


def _fail(message):
    """Print *message* to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _wants_rest_fallback(exc):
    """Return True when the SDK error text suggests retrying through REST."""
    text = str(exc).lower()
    return any(marker in text for marker in _REST_FALLBACK_MARKERS)


def _build_parser():
    """Create the command-line parser for this script."""
    parser = HelpOnErrorParser(description="Submit fine-tuning jobs on Azure AI Foundry")
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"),
                        help="API key")
    parser.add_argument("--model", required=True, help="Base model name (e.g., gpt-4.1-mini)")
    parser.add_argument("--type", choices=["sft", "dpo", "rft"], default="sft",
                        help="Training type: sft, dpo, or rft")

    # Data files — either paths (will upload) or IDs (already uploaded)
    parser.add_argument("--training-file", help="Path to training JSONL file (will upload)")
    parser.add_argument("--validation-file", help="Path to validation JSONL file (will upload)")
    parser.add_argument("--training-file-id", help="Already-uploaded training file ID")
    parser.add_argument("--validation-file-id", help="Already-uploaded validation file ID")

    # Hyperparameters
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1.0, help="Learning rate multiplier")
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--suffix", help="Model suffix for identification")

    # DPO-specific
    parser.add_argument("--beta", type=float, default=0.1, help="DPO beta (alignment strength)")

    # RFT-specific
    parser.add_argument("--grader-file", help="Path to Python grader file (for RFT)")

    # REST fallback
    parser.add_argument("--use-rest", action="store_true",
                        help="Force REST API (needed for gpt-oss-20b and other OSS models)")
    parser.add_argument("--training-type", choices=["globalStandard", "developerTier", "standard"],
                        default="globalStandard",
                        help="Azure training tier (default: globalStandard). developerTier is ~50%% off "
                             "globalStandard with lower quotas. OSS models (gpt-oss-20b, Ministral, "
                             "Llama, Qwen) only support globalStandard.")
    return parser


def main():
    """Parse arguments, upload data files if needed, and submit the job.

    All argument validation that needs no network access runs before any
    client is created or file uploaded, so a bad invocation does not leave
    orphaned uploads behind. On success the job summary is printed and saved
    to ``ft_job_<id>.json`` in the current directory. Validation failures are
    reported on stderr with exit status 1.
    """
    args = _build_parser().parse_args()

    # --- Validation that needs no network access -------------------------
    has_train = bool(args.training_file or args.training_file_id)
    has_val = bool(args.validation_file or args.validation_file_id)
    if not (has_train and has_val):
        _fail("Error: Provide training and validation file paths or IDs")

    grader_source = None
    if args.type == "rft":
        if not args.grader_file:
            _fail("Error: --grader-file required for RFT")
        # Read the grader up front so an unreadable path fails before upload.
        with open(args.grader_file, encoding="utf-8") as f:
            grader_source = f.read()

    rest_ready = bool(args.endpoint and args.api_key)
    if args.use_rest:
        if args.type != "sft":
            # REST submission is only implemented for SFT; the SDK is used.
            print(f"Warning: --use-rest is ignored for --type {args.type}", file=sys.stderr)
        elif not rest_ready:
            _fail("Error: --use-rest requires --endpoint and --api-key "
                  "(REST does not support DefaultAzureCredential)")

    # --- Clients and file IDs --------------------------------------------
    # The second value returned by get_clients is not used by this script.
    client, _method = get_clients(
        base_url=args.base_url, azure_endpoint=args.endpoint,
        project_endpoint=args.project_endpoint, api_key=args.api_key
    )

    # A file path takes precedence over an ID when both are supplied.
    train_id = args.training_file_id
    val_id = args.validation_file_id
    if args.training_file:
        train_id = upload_file(client, args.training_file)
    if args.validation_file:
        val_id = upload_file(client, args.validation_file)

    # Covers the case where an upload produced no usable ID.
    if not train_id or not val_id:
        _fail("Error: Provide training and validation file paths or IDs")

    # --- Submit ------------------------------------------------------------
    if args.type == "rft":
        result = submit_rft(client, args.model, train_id, val_id, grader_source)
    elif args.type == "dpo":
        result = submit_dpo(client, args.model, train_id, val_id,
                            args.epochs, args.lr, args.beta, args.suffix,
                            batch_size=args.batch_size)
    elif args.use_rest:
        result = submit_sft_rest(args.endpoint, args.api_key, args.model,
                                 train_id, val_id, args.epochs, args.lr,
                                 args.batch_size, args.suffix, args.training_type)
    else:
        # SFT via SDK with REST fallback for OSS models
        try:
            result = submit_sft_sdk(client, args.model, train_id, val_id,
                                    args.epochs, args.lr, args.batch_size,
                                    args.suffix, args.training_type)
        except Exception as e:
            # Only errors that look like "use REST instead" trigger the
            # fallback; anything else propagates unchanged.
            if not _wants_rest_fallback(e):
                raise
            if not rest_ready:
                _fail(f"SDK failed for {args.model}. REST fallback requires --endpoint and --api-key.")
            print(f"SDK failed for {args.model}, falling back to REST API...")
            result = submit_sft_rest(args.endpoint, args.api_key, args.model,
                                     train_id, val_id, args.epochs, args.lr,
                                     args.batch_size, args.suffix, args.training_type)

    print("\nJob submitted successfully:")
    print(json.dumps(result, indent=2))

    # Save job info
    outfile = f"ft_job_{result['id']}.json"
    with open(outfile, "w", encoding="utf-8") as f:
        json.dump({**result, "epochs": args.epochs, "lr": args.lr,
                   "batch_size": args.batch_size, "train_file": train_id,
                   "val_file": val_id}, f, indent=2)
    print(f"Job info saved to {outfile}")


if __name__ == "__main__":
    main()