Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Build and deploy AI applications on Azure AI Foundry using Microsoft's model catalog and AI services
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/convert_dataset.py
# /// script
# dependencies = [
#   "openai>=1.0",
# ]
# ///
r"""
convert_dataset.py — Convert between SFT, DPO, and RFT dataset formats.

Usage:
    # Parquet/CSV/JSON to SFT JSONL
    python convert_dataset.py --input data.parquet --output train.jsonl --format sft \
        --user-column prompt --assistant-column response --system-prompt "You are helpful."

    # SFT JSONL to DPO (generates rejected via base model)
    python convert_dataset.py --input train.jsonl --output dpo.jsonl --format dpo \
        --base-model gpt-4.1-mini --endpoint $ENDPOINT --api-key $KEY

    # SFT JSONL to RFT JSONL (drops assistant turns; the last assistant reply is kept
    # as a top-level "expected_answer" field for the grader)
    python convert_dataset.py --input train.jsonl --output rft.jsonl --format rft

    # DPO JSONL to SFT (extract chosen responses)
    python convert_dataset.py --input dpo.jsonl --output sft.jsonl --format sft-from-dpo
"""

import json
import os
import shutil
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# Make modules that live next to this script (e.g. common.py) importable
# regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients  # noqa: E402


def _load_jsonl(input_path):
    """Read a whole JSONL file, skipping blank and malformed lines.

    The file is read completely before returning, so callers can safely open
    their output afterwards even when it is the same path as the input.
    A leading UTF-8 BOM is accepted ("utf-8-sig").

    Returns:
        (records, malformed): ``records`` is a list of ``(line_number, obj)``
        tuples for each parsed line; ``malformed`` is how many lines failed to
        parse (each one is also reported on stdout).
    """
    records = []
    malformed = 0
    with open(input_path, encoding="utf-8-sig") as inf:
        for ln, raw in enumerate(inf, 1):
            if not raw.strip():
                continue
            try:
                records.append((ln, json.loads(raw)))
            except json.JSONDecodeError as e:
                print(f" ⚠️ Skipping malformed JSON on line {ln}: {e}")
                malformed += 1
    return records, malformed


def parquet_to_sft(input_path, output_path, user_col, assistant_col, system_prompt=None):
    """Convert a tabular file (.parquet, .csv, or .json) to SFT JSONL.

    Each row becomes one ``{"messages": [...]}`` line: an optional system
    message, then a user message taken from ``user_col`` and an assistant
    message taken from ``assistant_col``. Rows where either value is missing
    (NaN/None/NA) or blank after stripping are skipped.

    Exits the process with status 1 if pandas is not installed, the file
    extension is unsupported, or either column is absent.
    """
    try:
        import pandas as pd
    except ImportError:
        print("Error: pandas required. Install with: pip install pandas pyarrow")
        sys.exit(1)

    # Extension match is case-insensitive so "DATA.CSV" works too.
    readers = {
        ".parquet": pd.read_parquet,
        ".csv": pd.read_csv,
        ".json": pd.read_json,
    }
    reader = readers.get(os.path.splitext(input_path)[1].lower())
    if reader is None:
        print(f"Unsupported format: {input_path}. Use .parquet, .csv, or .json")
        sys.exit(1)
    df = reader(input_path)

    if user_col not in df.columns or assistant_col not in df.columns:
        print(f"Error: Columns '{user_col}' and/or '{assistant_col}' not found.")
        print(f"Available columns: {list(df.columns)}")
        sys.exit(1)

    def cell_text(value):
        # str() alone renders missing cells as "nan" / "None" / "<NA>", which would
        # then be written out as real training text; map them to "" so the row is
        # skipped. is_scalar guards pd.isna against list-like cells.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value).strip()

    count = 0
    skipped = 0
    with open(output_path, "w", encoding="utf-8") as f:
        # itertuples over just the two columns: faster than iterrows and keeps each
        # column's own dtype (iterrows builds a Series per row and may upcast).
        pairs = df[[user_col, assistant_col]].itertuples(index=False, name=None)
        for user_val, asst_val in pairs:
            user_content = cell_text(user_val)
            asst_content = cell_text(asst_val)
            if not user_content or not asst_content:
                skipped += 1
                continue

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": user_content})
            messages.append({"role": "assistant", "content": asst_content})

            f.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Converted {count} examples to SFT JSONL → {output_path}")
    if skipped:
        print(f" Skipped {skipped} rows with a missing or empty user/assistant value")


def _build_dpo_entry(index, example, client, base_model):
    """Build one DPO record from an SFT example, or return None to skip it.

    Uses all system messages plus the FIRST user and FIRST assistant message;
    later turns of a multi-turn conversation are ignored. The non-preferred
    response is generated by ``base_model`` through ``client``.
    """
    msgs = example.get("messages") if isinstance(example, dict) else None
    if not isinstance(msgs, list):
        print(f" Skipping example {index}: no 'messages' list")
        return None

    system_msgs = [m for m in msgs if m.get("role") == "system"]
    user_msg = next((m for m in msgs if m.get("role") == "user"), None)
    asst_msg = next((m for m in msgs if m.get("role") == "assistant"), None)
    if not user_msg or not asst_msg:
        print(f" Skipping example {index}: needs a user and an assistant message")
        return None

    input_messages = system_msgs + [user_msg]
    try:
        resp = client.chat.completions.create(
            model=base_model,
            messages=input_messages,
            temperature=1.0,  # High temp for diversity
            max_completion_tokens=2048,
        )
        rejected_content = resp.choices[0].message.content
    except Exception as e:  # noqa: BLE001 - per-example boundary: one failed call must not abort the run
        print(f" Skipping example {index}: {e}")
        return None

    if not rejected_content:
        # None or empty — content filter, finish=length with no text, etc.
        # Skip rather than emit a DPO entry with null content (trainer rejects).
        print(f" Skipping example {index}: base model returned no content")
        return None

    return {
        "input": {"messages": input_messages},
        "preferred_output": [asst_msg],
        "non_preferred_output": [{"role": "assistant", "content": rejected_content}],
    }


def sft_to_dpo(input_path, output_path, client, base_model):
    """Convert SFT to DPO by generating non-preferred responses from a base model.

    DPO format uses: input (system+user messages), preferred_output, non_preferred_output.
    The existing assistant reply becomes the preferred output; examples that
    cannot be converted (missing turns, API error, empty generation) are skipped.
    """
    # Read everything first: gives a total for progress output and guarantees the
    # input is consumed before output_path is truncated (safe if both are the same).
    examples = [ex for _, ex in _load_jsonl(input_path)[0]]
    total = len(examples)
    count = 0

    with open(output_path, "w", encoding="utf-8") as f:
        for i, ex in enumerate(examples):
            entry = _build_dpo_entry(i, ex, client, base_model)
            if entry is not None:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
                count += 1

            # Progress report plus a short pause every 50 inputs (skipped ones
            # included) to throttle the API calls.
            if (i + 1) % 50 == 0:
                print(f" Processed {i+1}/{total}")
                time.sleep(1)

    print(f"Converted {count} examples to DPO JSONL → {output_path}")


def sft_to_rft(input_path, output_path):
    """Convert SFT JSONL to RFT JSONL.

    Keeps only system and user messages (RFT's last message must be from the
    user) and copies the content of the last assistant message into a
    top-level ``expected_answer`` field. Review those values and point your
    grader at ``item.expected_answer`` before training.
    """
    # Read fully before opening the output so --input == --output cannot truncate
    # the data before it is read.
    records, malformed = _load_jsonl(input_path)
    count = 0
    no_user = 0

    with open(output_path, "w", encoding="utf-8") as out:
        for _, ex in records:
            msgs = (ex.get("messages") if isinstance(ex, dict) else None) or []
            rft_msgs = [m for m in msgs if m.get("role") in ("system", "user")]
            if not rft_msgs or rft_msgs[-1]["role"] != "user":
                no_user += 1
                continue
            asst_msgs = [m for m in msgs if m.get("role") == "assistant"]
            # Tool-call-only assistant messages may have no/None content; use "".
            expected = (asst_msgs[-1].get("content") or "") if asst_msgs else ""
            rft_entry = {"messages": rft_msgs, "expected_answer": expected}
            out.write(json.dumps(rft_entry, ensure_ascii=False) + "\n")
            count += 1

    print(f"Converted {count} examples to RFT JSONL → {output_path}")
    if no_user:
        print(f" Skipped {no_user} examples (no user message)")
    if malformed:
        print(f" Skipped {malformed} malformed JSON lines")
    print("Note: Review 'expected_answer' fields and update your grader to use item.expected_answer.")


def dpo_to_sft(input_path, output_path, system_prompt=None):
    """Extract chosen responses from DPO format to SFT format.

    Each output line is the DPO input messages followed by the preferred
    output. When ``system_prompt`` is given it replaces any system messages
    from the DPO input. Records lacking ``input.messages`` or
    ``preferred_output`` are skipped with a warning.
    """
    # Read fully before opening the output (safe if --input == --output).
    records, _ = _load_jsonl(input_path)
    count = 0

    with open(output_path, "w", encoding="utf-8") as f:
        for ln, ex in records:
            try:
                input_messages = ex["input"]["messages"]
                chosen_messages = ex["preferred_output"]
            except (KeyError, TypeError):
                print(f" ⚠️ Skipping line {ln}: expected 'input.messages' and 'preferred_output'")
                continue

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
                messages.extend(m for m in input_messages if m.get("role") != "system")
            else:
                messages.extend(input_messages)
            messages.extend(chosen_messages)
            f.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Extracted {count} chosen examples to SFT JSONL → {output_path}")


def main():
    """Parse CLI arguments and dispatch to the requested conversion."""
    parser = HelpOnErrorParser(description="Convert between fine-tuning dataset formats")
    parser.add_argument("--input", required=True, help="Input file path")
    parser.add_argument("--output", required=True, help="Output file path")
    parser.add_argument("--format", required=True,
                        choices=["sft", "dpo", "rft", "sft-from-dpo"],
                        help="Target format")

    # SFT from raw data
    parser.add_argument("--user-column", default="prompt", help="Column name for user input")
    parser.add_argument("--assistant-column", default="response", help="Column name for assistant output")
    parser.add_argument("--system-prompt", default=None, help="System prompt to prepend")

    # DPO generation (needs API connection)
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"),
                        help="API key (defaults to $AZURE_OPENAI_API_KEY)")
    parser.add_argument("--base-model", default="gpt-4.1-mini", help="Base model for generating rejections")

    args = parser.parse_args()

    if args.format == "sft":
        if args.input.lower().endswith(".jsonl"):
            print("Input is already JSONL — assuming SFT format. Nothing to convert.")
            try:
                shutil.copy2(args.input, args.output)
            except shutil.SameFileError:
                pass  # --input and --output name the same file; nothing to copy
        else:
            parquet_to_sft(args.input, args.output, args.user_column,
                           args.assistant_column, args.system_prompt)

    elif args.format == "dpo":
        client, _method = get_clients(
            base_url=args.base_url, azure_endpoint=args.endpoint,
            project_endpoint=args.project_endpoint, api_key=args.api_key
        )
        sft_to_dpo(args.input, args.output, client, args.base_model)

    elif args.format == "rft":
        sft_to_rft(args.input, args.output)

    elif args.format == "sft-from-dpo":
        dpo_to_sft(args.input, args.output, args.system_prompt)


if __name__ == "__main__":
    main()
Nothing to convert.")238if args.input != args.output:239import shutil240shutil.copy2(args.input, args.output)241else:242parquet_to_sft(args.input, args.output, args.user_column,243args.assistant_column, args.system_prompt)244245elif args.format == "dpo":246client, method = get_clients(247base_url=args.base_url, azure_endpoint=args.endpoint,248project_endpoint=args.project_endpoint, api_key=args.api_key249)250sft_to_dpo(args.input, args.output, client, args.base_model)251252elif args.format == "rft":253sft_to_rft(args.input, args.output)254255elif args.format == "sft-from-dpo":256dpo_to_sft(args.input, args.output, args.system_prompt)257258259if __name__ == "__main__":260main()261