Loading source
Pulling the file list, source metadata, and syntax-aware rendering for this listing.
Source from repo
Deploy, evaluate, and manage AI agents end-to-end on Microsoft Azure AI Foundry
Files
Skill
Size
Entrypoint
Format
Open file
Syntax-highlighted preview of this file as included in the skill package.
finetuning/scripts/convert_dataset.py
# /// script
# dependencies = [
#   "openai>=1.0",
# ]
# ///
"""
convert_dataset.py — Convert between SFT, DPO, and RFT dataset formats.

Usage:
    # Parquet/CSV/JSON to SFT JSONL
    python convert_dataset.py --input data.parquet --output train.jsonl --format sft \
        --user-column prompt --assistant-column response --system-prompt "You are helpful."

    # SFT JSONL to DPO (generates rejected via base model)
    python convert_dataset.py --input train.jsonl --output dpo.jsonl --format dpo \
        --base-model gpt-4.1-mini --endpoint $ENDPOINT --api-key $KEY

    # SFT JSONL to RFT JSONL (drops assistant turns; keeps the last one as expected_answer)
    python convert_dataset.py --input train.jsonl --output rft.jsonl --format rft

    # DPO JSONL to SFT (extract chosen responses)
    python convert_dataset.py --input dpo.jsonl --output sft.jsonl --format sft-from-dpo
"""

import json
import os
import shutil
import sys
import time

try:
    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")
except (AttributeError, OSError):
    pass  # Stream not reconfigurable (older Python or non-tty); default encoding is fine

# Make the sibling ``common`` module importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common import HelpOnErrorParser, get_clients


def _parse_jsonl_line(raw, line_no):
    """Parse one JSONL line into a dict.

    Prints a warning and returns None when the line is not valid JSON or is
    valid JSON but not an object, so callers can skip it instead of crashing.
    """
    try:
        record = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f" ⚠️ Skipping malformed JSON on line {line_no}: {e}")
        return None
    if not isinstance(record, dict):
        print(f" ⚠️ Skipping line {line_no}: expected a JSON object")
        return None
    return record


def _chat_messages(record):
    """Return ``record["messages"]`` if it is a list of dicts, else an empty list."""
    messages = record.get("messages")
    if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
        return messages
    return []


def _require_distinct_paths(input_path, output_path):
    """Exit with an error if input and output resolve to the same file.

    The streaming converters open the output for writing while still reading
    the input; pointing both at one file would truncate the data being read.
    """
    if os.path.realpath(input_path) == os.path.realpath(output_path):
        print("Error: --input and --output must be different files.")
        sys.exit(1)


def parquet_to_sft(input_path, output_path, user_col, assistant_col, system_prompt=None):
    """Convert a Parquet, CSV, or JSON table to SFT JSONL.

    Args:
        input_path: Path ending in .parquet, .csv, or .json (case-insensitive).
        output_path: Destination JSONL file (overwritten).
        user_col: Column holding the user turn.
        assistant_col: Column holding the assistant turn.
        system_prompt: Optional system message prepended to every example.

    Rows whose user or assistant cell is missing or blank are skipped.
    Exits the process with status 1 if pandas is unavailable, the extension
    is unsupported, or a requested column does not exist.
    """
    try:
        import pandas as pd
    except ImportError:
        print("Error: pandas required. Install with: pip install pandas pyarrow")
        sys.exit(1)

    readers = {
        ".parquet": pd.read_parquet,
        ".csv": pd.read_csv,
        ".json": pd.read_json,
    }
    lowered = input_path.lower()
    reader = next((fn for ext, fn in readers.items() if lowered.endswith(ext)), None)
    if reader is None:
        print(f"Unsupported format: {input_path}. Use .parquet, .csv, or .json")
        sys.exit(1)
    df = reader(input_path)

    if user_col not in df.columns or assistant_col not in df.columns:
        print(f"Error: Columns '{user_col}' and/or '{assistant_col}' not found.")
        print(f"Available columns: {list(df.columns)}")
        sys.exit(1)

    def cell_text(value):
        # str(NaN) is "nan" and str(None) is "None", both truthy, so missing
        # cells must be detected *before* stringifying. is_scalar() guards
        # against list/array cells, for which isna() returns an array.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value).strip()

    count = 0
    with open(output_path, "w", encoding="utf-8") as f:
        # Iterate the two columns directly: iterrows() builds a Series per row
        # and does not preserve column dtypes across the row.
        for user_val, asst_val in zip(df[user_col], df[assistant_col]):
            user_content = cell_text(user_val)
            asst_content = cell_text(asst_val)
            if not user_content or not asst_content:
                continue

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": user_content})
            messages.append({"role": "assistant", "content": asst_content})

            f.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
            count += 1

    print(f"Converted {count} examples to SFT JSONL → {output_path}")


def _build_dpo_entry(example, index, client, base_model):
    """Build one DPO record from an SFT example, or return None to skip it.

    Uses the first user and first assistant message of the example. The
    rejected answer is sampled from ``base_model`` through ``client``.
    """
    msgs = _chat_messages(example)
    system_msgs = [m for m in msgs if m.get("role") == "system"]
    user_msg = next((m for m in msgs if m.get("role") == "user"), None)
    asst_msg = next((m for m in msgs if m.get("role") == "assistant"), None)
    if not user_msg or not asst_msg:
        return None

    prompt_msgs = system_msgs + [user_msg]

    # Generate a non-preferred response from the base model. Any client error
    # skips just this example (best effort) rather than aborting the run.
    try:
        resp = client.chat.completions.create(
            model=base_model,
            messages=prompt_msgs,
            temperature=1.0,  # High temp for diversity
            max_completion_tokens=2048,
        )
        rejected_content = resp.choices[0].message.content
    except Exception as e:
        print(f" Skipping example {index}: {e}")
        return None

    if not rejected_content:
        # None or empty — content filter, finish=length with no text, etc.
        # Skip rather than emit a DPO entry with null content (trainer rejects).
        print(f" Skipping example {index}: base model returned no content")
        return None

    return {
        "input": {"messages": prompt_msgs},
        "preferred_output": [asst_msg],
        "non_preferred_output": [{"role": "assistant", "content": rejected_content}],
    }


def sft_to_dpo(input_path, output_path, client, base_model):
    """Convert SFT to DPO by generating non-preferred responses from a base model.

    DPO format uses: input (system+user messages), preferred_output, non_preferred_output.

    Args:
        input_path: SFT JSONL file; it is read completely before any output is written.
        output_path: Destination DPO JSONL file (overwritten).
        client: Object exposing ``chat.completions.create`` (as returned by ``get_clients``).
        base_model: Model/deployment name used to generate the rejected answer.

    Malformed lines, examples lacking a user or assistant message, and examples
    for which generation fails or returns nothing are skipped.
    """
    examples = []
    with open(input_path, encoding="utf-8") as inf:
        for ln, raw in enumerate(inf, 1):
            if not raw.strip():
                continue
            record = _parse_jsonl_line(raw, ln)
            if record is not None:
                examples.append(record)

    count = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for i, ex in enumerate(examples):
            dpo_entry = _build_dpo_entry(ex, i, client, base_model)
            if dpo_entry is not None:
                f.write(json.dumps(dpo_entry, ensure_ascii=False) + "\n")
                count += 1

            # Report progress and pause on every 50th example, including
            # skipped ones, so neither the output nor the pacing stalls.
            if (i + 1) % 50 == 0:
                print(f" Processed {i+1}/{len(examples)}")
                time.sleep(1)

    print(f"Converted {count} examples to DPO JSONL → {output_path}")


def sft_to_rft(input_path, output_path):
    """Convert SFT to RFT format.

    Strips assistant messages (RFT last message must be user) and stores the
    content of the last assistant message in ``expected_answer`` (empty string
    when there is none). The user should review these fields before training.

    Lines that are malformed, or whose remaining messages do not end with a
    user message, are skipped and counted. Exits with status 1 if input and
    output are the same file.
    """
    _require_distinct_paths(input_path, output_path)
    count = 0
    skipped = 0
    # Open the input first so a missing input does not leave an empty output behind.
    with open(input_path, encoding="utf-8") as inf, \
            open(output_path, "w", encoding="utf-8") as out:
        for ln, line in enumerate(inf, 1):
            if not line.strip():
                continue
            ex = _parse_jsonl_line(line, ln)
            if ex is None:
                skipped += 1
                continue
            msgs = _chat_messages(ex)
            # Keep only system + user messages; RFT last message must be user
            rft_msgs = [m for m in msgs if m.get("role") in ("system", "user")]
            if not rft_msgs or rft_msgs[-1].get("role") != "user":
                skipped += 1
                continue
            # Extract assistant content as a reference answer placeholder
            asst_msgs = [m for m in msgs if m.get("role") == "assistant"]
            expected = (asst_msgs[-1].get("content") or "") if asst_msgs else ""
            rft_entry = {"messages": rft_msgs, "expected_answer": expected}
            out.write(json.dumps(rft_entry, ensure_ascii=False) + "\n")
            count += 1
    print(f"Converted {count} examples to RFT JSONL → {output_path}")
    if skipped:
        print(f" Skipped {skipped} examples (no user message)")
    print("Note: Review 'expected_answer' fields and update your grader to use item.expected_answer.")


def dpo_to_sft(input_path, output_path, system_prompt=None):
    """Extract chosen responses from DPO format to SFT format.

    Args:
        input_path: DPO JSONL with ``input.messages`` and ``preferred_output`` lists.
        output_path: Destination SFT JSONL file (overwritten).
        system_prompt: When given, replaces any system messages from the input.

    Lines that are malformed or lack the expected fields are skipped with a
    warning. Exits with status 1 if input and output are the same file.
    """
    _require_distinct_paths(input_path, output_path)
    count = 0
    # Open the input first so a missing input does not leave an empty output behind.
    with open(input_path, encoding="utf-8") as inf, \
            open(output_path, "w", encoding="utf-8") as f:
        for ln, line in enumerate(inf, 1):
            if not line.strip():
                continue
            ex = _parse_jsonl_line(line, ln)
            if ex is None:
                continue

            input_part = ex.get("input")
            input_messages = input_part.get("messages") if isinstance(input_part, dict) else None
            chosen_messages = ex.get("preferred_output")
            if not isinstance(input_messages, list) or not isinstance(chosen_messages, list):
                print(f" ⚠️ Skipping line {ln}: missing 'input.messages' or 'preferred_output'")
                continue

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
                messages.extend(
                    m for m in input_messages
                    if not (isinstance(m, dict) and m.get("role") == "system")
                )
            else:
                messages.extend(input_messages)
            messages.extend(chosen_messages)
            f.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
            count += 1
    print(f"Extracted {count} chosen examples to SFT JSONL → {output_path}")


def main():
    """Parse command-line arguments and run the requested conversion."""
    parser = HelpOnErrorParser(description="Convert between fine-tuning dataset formats")
    parser.add_argument("--input", required=True, help="Input file path")
    parser.add_argument("--output", required=True, help="Output file path")
    parser.add_argument("--format", required=True,
                        choices=["sft", "dpo", "rft", "sft-from-dpo"],
                        help="Target format")

    # SFT from raw data
    parser.add_argument("--user-column", default="prompt", help="Column name for user input")
    parser.add_argument("--assistant-column", default="response", help="Column name for assistant output")
    parser.add_argument("--system-prompt", default=None, help="System prompt to prepend")

    # DPO generation (needs API connection)
    parser.add_argument("--base-url", default=os.environ.get("OPENAI_BASE_URL"),
                        help="Project /v1/ URL (preferred)")
    parser.add_argument("--endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
                        help="Azure OpenAI endpoint (fallback)")
    parser.add_argument("--project-endpoint", default=os.environ.get("AZURE_AI_PROJECT_ENDPOINT"),
                        help="Azure AI project endpoint (Foundry SDK)")
    parser.add_argument("--api-key", default=os.environ.get("AZURE_OPENAI_API_KEY"))
    parser.add_argument("--base-model", default="gpt-4.1-mini", help="Base model for generating rejections")

    args = parser.parse_args()

    if args.format == "sft":
        if args.input.endswith(".jsonl"):
            print("Input is already JSONL — assuming SFT format. Nothing to convert.")
            if args.input != args.output:
                shutil.copy2(args.input, args.output)
        else:
            parquet_to_sft(args.input, args.output, args.user_column,
                           args.assistant_column, args.system_prompt)

    elif args.format == "dpo":
        # get_clients returns a pair; only the client itself is needed here.
        client, _method = get_clients(
            base_url=args.base_url, azure_endpoint=args.endpoint,
            project_endpoint=args.project_endpoint, api_key=args.api_key
        )
        sft_to_dpo(args.input, args.output, client, args.base_model)

    elif args.format == "rft":
        sft_to_rft(args.input, args.output)

    elif args.format == "sft-from-dpo":
        dpo_to_sft(args.input, args.output, args.system_prompt)


if __name__ == "__main__":
    main()