aipackage/ansible/roles/cezen-backend/files/train_qlora.py

310 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Nexus One AI — QLoRA Fine-Tuning Runner
Launched as a subprocess by the FastAPI backend.
Writes structured JSONL log lines to --log-path so the UI can stream
live loss curves and progress. Updates training_jobs.status in SQLite.
Requires (install on the training node):
pip install torch transformers datasets peft bitsandbytes trl
Optional (faster, lower VRAM):
pip install unsloth
Usage (called by main.py — do not run manually in production):
python3 train_qlora.py --job-id 1 --db-path /opt/cezen/data/cezen.db \
--dataset /opt/cezen/data/datasets/abc.jsonl \
--base-model mistral:7b --output-dir /opt/cezen/data/finetuned/mymodel \
--log-path /opt/cezen/data/job_logs/abc.jsonl \
--epochs 3 --lr 2e-4 --batch-size 4 --lora-r 16 --lora-alpha 32 \
--output-name mymodel
"""
import argparse
import json
import math
import os
import sqlite3
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# ── Argument parsing ──────────────────────────────────────────────────────────
def _positive(cast):
    """Wrap an argparse ``type=`` callable so that values <= 0 are rejected.

    Invalid numbers still produce argparse's usual "invalid <type> value"
    error; non-positive ones produce "must be > 0" at parse time instead of
    failing much later inside training.
    """
    def convert(text):
        value = cast(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"must be > 0, got {text!r}")
        return value

    # argparse uses the converter's __name__ in its "invalid <type> value" message.
    convert.__name__ = cast.__name__
    return convert


parser = argparse.ArgumentParser(
    description="QLoRA fine-tuning runner for one training_jobs row "
                "(normally launched by the backend, not by hand).")
parser.add_argument("--job-id", type=int, required=True,
                    help="training_jobs.id whose status this run updates")
parser.add_argument("--db-path", required=True,
                    help="path to the SQLite database")
parser.add_argument("--dataset", required=True,
                    help="training data: a .csv file, otherwise JSON Lines")
parser.add_argument("--base-model", required=True,
                    help='Ollama-style name (e.g. "mistral:7b") or HuggingFace repo id')
parser.add_argument("--output-dir", required=True,
                    help="directory for checkpoints and the final model")
parser.add_argument("--log-path", required=True,
                    help="JSONL progress log (appended to)")
parser.add_argument("--output-name", required=True,
                    help="name to register the fine-tuned model under in Ollama")
parser.add_argument("--epochs", type=_positive(int), default=3,
                    help="number of training epochs")
parser.add_argument("--lr", type=_positive(float), default=2e-4,
                    help="learning rate")
parser.add_argument("--batch-size", type=_positive(int), default=4,
                    help="per-device training batch size")
parser.add_argument("--lora-r", type=_positive(int), default=16,
                    help="LoRA rank")
parser.add_argument("--lora-alpha", type=_positive(int), default=32,
                    help="LoRA scaling alpha")
args = parser.parse_args()
# ── Helpers ───────────────────────────────────────────────────────────────────
def utcnow():
    """Return the current moment as a timezone-aware ISO-8601 string in UTC."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat()
def db_connect(timeout: float = 30.0):
    """Open a connection to the job database named by --db-path.

    Rows are returned as sqlite3.Row, so columns can be read by name.

    Args:
        timeout: seconds to wait for a competing writer's lock before raising
            "database is locked". This is raised from sqlite3's 5 s default
            because the same database file is presumably written by the
            backend that launched this job. A lock error here would leave the
            job's status stale.
    """
    conn = sqlite3.connect(args.db_path, timeout=timeout)
    conn.row_factory = sqlite3.Row
    return conn
# Create the log directory if needed, so the job does not crash at import
# time before it has reported anything. buffering=1 means line-buffered: each
# JSON line is flushed as soon as it is written, so the UI can stream progress.
Path(args.log_path).parent.mkdir(parents=True, exist_ok=True)
log_file = open(args.log_path, "a", encoding="utf-8", buffering=1)


def log(type_: str, **kwargs):
    """Append one JSON event line to the job log.

    Each line is an object with "ts" (UTC ISO-8601), "type" (= type_) and the
    given keyword fields. Non-finite float fields (NaN/Infinity, e.g. from a
    diverging loss) are written as null. json.dumps would otherwise emit bare
    NaN/Infinity, which is not valid JSON and breaks strict parsers.
    """
    entry = {"ts": utcnow(), "type": type_}
    for key, value in kwargs.items():
        if isinstance(value, float) and not math.isfinite(value):
            value = None
        entry[key] = value
    log_file.write(json.dumps(entry) + "\n")
def set_status(status: str):
    """Record `status` on this job's training_jobs row and commit.

    "completed", "failed" and "cancelled" are terminal statuses: they also set
    finished_at to the current UTC timestamp. Any other status updates only
    the status column. Database errors propagate to the caller.
    """
    db = db_connect()
    try:
        if status in ("completed", "failed", "cancelled"):
            db.execute(
                "UPDATE training_jobs SET status=?, finished_at=? WHERE id=?",
                (status, utcnow(), args.job_id),
            )
        else:
            db.execute("UPDATE training_jobs SET status=? WHERE id=?", (status, args.job_id))
        db.commit()
    finally:
        # Release the connection even when the UPDATE fails (e.g. the database
        # is locked). Several callers are already on an error path.
        db.close()
# ── Dataset loading ───────────────────────────────────────────────────────────
def load_dataset_from_file(path: str):
    """Load a CSV or JSON Lines dataset file into a list of dicts.

    Files whose suffix is ".csv" (any case) are read with csv.DictReader: one
    dict per data row, keyed by the header row. Any other file is read as
    JSON Lines. Blank lines, lines that are not valid JSON, and lines whose
    JSON value is not an object are skipped.

    Rows are expected to carry "text", "prompt"/"completion" or
    "instruction"/"output" keys (the layouts format_row understands), but this
    is not enforced here.

    Both formats are decoded as UTF-8 with undecodable bytes replaced. A
    leading UTF-8 BOM, common in spreadsheet exports, is stripped. Otherwise it
    would corrupt the first CSV header name or make the first JSON line
    unparseable.
    """
    rows = []
    if Path(path).suffix.lower() == ".csv":
        import csv
        with open(path, newline="", encoding="utf-8-sig", errors="replace") as f:
            rows.extend(dict(record) for record in csv.DictReader(f))
    else:
        with open(path, encoding="utf-8-sig", errors="replace") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
                    continue
                # Downstream formatting needs a mapping. Skip bare
                # strings, numbers and lists instead of crashing later.
                if isinstance(row, dict):
                    rows.append(row)
    return rows
def format_row(row: dict) -> str:
    """Render one dataset row as a plain-text training example.

    Layouts are checked in this order:
      * {"text": ...}                        -> the text itself
      * {"prompt": ..., "completion": ...}   -> "### Instruction" / "### Response"
      * {"instruction": ..., "output": ...}  -> Alpaca-style prompt. A
        "### Input" section is added only when "input" is present and truthy.
    Any other row falls back to its truthy values, str()-ed and joined with
    single spaces.

    A "text" value that is not a string (JSON null or a number) is converted
    to str, with None -> "", so the result is always a str.
    """
    if "text" in row:
        text = row["text"]
        if text is None:
            return ""
        return text if isinstance(text, str) else str(text)
    if "prompt" in row and "completion" in row:
        return f"### Instruction:\n{row['prompt']}\n\n### Response:\n{row['completion']}"
    if "instruction" in row and "output" in row:
        inp = row.get("input", "")
        if inp:
            return (f"### Instruction:\n{row['instruction']}\n\n"
                    f"### Input:\n{inp}\n\n"
                    f"### Response:\n{row['output']}")
        return f"### Instruction:\n{row['instruction']}\n\n### Response:\n{row['output']}"
    # Fallback: concatenate all truthy values
    return " ".join(str(v) for v in row.values() if v)
# ── Main training routine ─────────────────────────────────────────────────────
# Ollama-style base names -> HuggingFace repos that are downloaded for training.
_HF_MODEL_ALIASES = {
    "mistral": "mistralai/Mistral-7B-v0.1",
    "llama2": "meta-llama/Llama-2-7b-hf",
    "llama3": "meta-llama/Meta-Llama-3-8B",
    "phi3": "microsoft/Phi-3-mini-4k-instruct",
    "gemma": "google/gemma-7b",
    "codellama": "codellama/CodeLlama-7b-hf",
    "qwen2": "Qwen/Qwen2-7B",
}

# Token budget per training example; longer examples are truncated.
_MAX_SEQ_LEN = 2048


def _resolve_hf_model(name: str) -> str:
    """Translate an Ollama-style model name into a HuggingFace repo id.

    Any ":tag" suffix is ignored, so "mistral:7b" and "mistral" both use the
    "mistral" alias. Names containing "/" (repo ids or paths) and unknown
    aliases are returned unchanged.
    """
    if "/" in name:
        return name
    return _HF_MODEL_ALIASES.get(name.split(":")[0].lower(), name)


def _import_unsloth():
    """Return Unsloth's FastLanguageModel, or None if Unsloth cannot be used.

    main() calls this before importing transformers, as the original flow
    did, so that Unsloth can patch transformers. A failure other than
    ImportError (Unsloth may refuse to start, e.g. without a supported GPU)
    also falls back to plain PEFT instead of aborting the job.
    """
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        log("info", msg="Unsloth not available — using HuggingFace PEFT + BitsAndBytes")
        return None
    except Exception as e:
        log("info", msg=f"Unsloth failed to load ({e}) — using HuggingFace PEFT + BitsAndBytes")
        return None
    log("info", msg="Using Unsloth for accelerated training")
    return FastLanguageModel


def _load_unsloth_model(fast_lm, hf_model: str):
    """Load hf_model in 4-bit through Unsloth and attach LoRA adapters.

    Returns (model, tokenizer).
    """
    model, tokenizer = fast_lm.from_pretrained(
        model_name=hf_model,
        max_seq_length=_MAX_SEQ_LEN,
        dtype=None,  # let Unsloth choose
        load_in_4bit=True,
    )
    model = fast_lm.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
    )
    return model, tokenizer


def _load_peft_model(hf_model: str, compute_dtype):
    """Load hf_model as 4-bit NF4 via bitsandbytes and attach PEFT LoRA adapters.

    compute_dtype is the torch dtype used for the 4-bit matmuls. main()
    passes the same dtype it trains in (bf16 or fp16). Returns
    (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # model repo. If --base-model can be chosen by untrusted users, this is a
    # code-execution risk.
    tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        hf_model,
        quantization_config=bnb_cfg,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    lora_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, lora_cfg), tokenizer


def main():
    """Run one QLoRA fine-tuning job end to end.

    The job:
      1. resolves the base model name;
      2. loads and formats the dataset;
      3. loads the 4-bit model with LoRA adapters (Unsloth if available,
         otherwise PEFT + bitsandbytes);
      4. trains while streaming loss lines to the JSONL log;
      5. saves the model and tokenizer to --output-dir;
      6. tries to register the result with Ollama.

    The job's status goes running -> completed, or running -> failed, in
    which case the process exits with status 1. Ollama registration failures
    are only logged as warnings.
    """
    log("start", job_id=args.job_id, base_model=args.base_model,
        epochs=args.epochs, lr=args.lr, batch_size=args.batch_size,
        lora_r=args.lora_r, lora_alpha=args.lora_alpha)
    set_status("running")

    # Ollama names like "mistral:7b" are not HuggingFace repo ids; map them.
    hf_model = _resolve_hf_model(args.base_model)
    if hf_model != args.base_model:
        log("info", msg=f"Mapped '{args.base_model}' → '{hf_model}' (HuggingFace)")

    log("info", msg="Loading dataset...")
    raw_rows = load_dataset_from_file(args.dataset)
    # Drop rows that format to nothing; they would train on padding only.
    texts = [t for t in map(format_row, raw_rows) if isinstance(t, str) and t.strip()]
    if not texts:
        log("error", msg="Dataset is empty or could not be parsed")
        set_status("failed")
        sys.exit(1)
    log("info", msg=f"Loaded {len(texts)} training examples")

    # Must happen before transformers is imported below.
    fast_lm = _import_unsloth()

    try:
        import torch
        from datasets import Dataset as HFDataset
        from transformers import (
            DataCollatorForLanguageModeling,
            Trainer,
            TrainerCallback,
            TrainingArguments,
        )

        # bitsandbytes 4-bit loading needs CUDA. Fail with a clear message
        # instead of an obscure error deep inside model loading.
        if not torch.cuda.is_available():
            raise RuntimeError("QLoRA training requires a CUDA GPU, but none is available")
        use_bf16 = torch.cuda.is_bf16_supported()

        if fast_lm is not None:
            model, tokenizer = _load_unsloth_model(fast_lm, hf_model)
        else:
            model, tokenizer = _load_peft_model(
                hf_model, torch.bfloat16 if use_bf16 else torch.float16)
        if tokenizer.pad_token is None:
            # The collator needs a pad token; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token

        def tokenise(batch):
            # No padding here: the collator pads each batch only to its longest
            # example, instead of padding every example to _MAX_SEQ_LEN tokens.
            return tokenizer(batch["text"], truncation=True, max_length=_MAX_SEQ_LEN)

        hf_ds = HFDataset.from_dict({"text": texts})
        hf_ds = hf_ds.map(tokenise, batched=True, remove_columns=["text"])

        class LossLogger(TrainerCallback):
            """Forward each loss reported by the Trainer to the job's JSONL log."""

            def on_log(self, _args, state, control, logs=None, **kwargs):
                # `_args` (the TrainingArguments) is unused. The fallback lr
                # below is this script's CLI `args.lr`.
                if logs and "loss" in logs:
                    log("loss",
                        step=state.global_step,
                        loss=round(float(logs["loss"]), 6),
                        epoch=round(float(logs.get("epoch", 0)), 3),
                        lr=float(logs.get("learning_rate", args.lr)))

        output_dir = args.output_dir
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # The dataset is already tokenised, so a plain Trainer with a causal-LM
        # collator (labels = input_ids, padding masked) is all that is needed.
        # Newer trl releases no longer accept tokenizer= / dataset_text_field= /
        # max_seq_length= on SFTTrainer, and "input_ids" is not a text field.
        trainer = Trainer(
            model=model,
            train_dataset=hf_ds,
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            args=TrainingArguments(
                output_dir=output_dir,
                num_train_epochs=args.epochs,
                per_device_train_batch_size=args.batch_size,
                gradient_accumulation_steps=4,
                warmup_steps=5,
                learning_rate=args.lr,
                fp16=not use_bf16,
                bf16=use_bf16,
                logging_steps=1,
                save_strategy="epoch",
                report_to="none",
            ),
            callbacks=[LossLogger()],
        )

        log("info", msg="Training started")
        trainer.train()
        log("info", msg="Training complete — saving model")
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
    except Exception as e:
        import traceback
        log("error", msg=str(e), traceback=traceback.format_exc())
        set_status("failed")
        sys.exit(1)

    # Best-effort: the trained model is already on disk at this point.
    try:
        _register_with_ollama(output_dir, args.output_name)
    except Exception as e:
        log("warning", msg=f"Could not auto-register with Ollama: {e}")

    log("complete", msg="Job finished successfully", output_dir=output_dir)
    set_status("completed")
def _register_with_ollama(model_dir: str, model_name: str, base_model=None):
    """Write an Ollama Modelfile into model_dir and run ``ollama create``.

    Saving a PEFT model writes only the LoRA adapter (adapter_config.json plus
    adapter weights), which Ollama cannot load with a bare ``FROM <dir>``.

    * If adapter_config.json is present, the Modelfile uses
      ``FROM <base_model>`` followed by ``ADAPTER <dir>``. base_model
      defaults to this job's --base-model.
    * Otherwise (a full model directory), it uses ``FROM <dir>`` as before.

    Paths are written as absolute paths.

    A non-zero exit from ``ollama create`` is logged as a warning. A missing
    ``ollama`` binary (FileNotFoundError) or the 300 s timeout
    (subprocess.TimeoutExpired) propagates to the caller.

    NOTE(review): Ollama applies the adapter to whatever weights the base tag
    resolves to. If those differ from the HuggingFace base that was trained,
    output quality may suffer. Confirm the tag/repo pairing.
    """
    import subprocess

    model_path = Path(model_dir).resolve()
    modelfile_path = model_path / "Modelfile"
    if (model_path / "adapter_config.json").is_file():
        base = base_model if base_model is not None else args.base_model
        source = f"FROM {base}\nADAPTER {model_path}\n"
    else:
        source = f"FROM {model_path}\n"
    modelfile_path.write_text(
        source
        + 'PARAMETER stop "<|im_end|>"\n'
        + 'SYSTEM "This is a Nexus One AI fine-tuned model."\n',
        encoding="utf-8",
    )
    result = subprocess.run(
        ["ollama", "create", model_name, "-f", str(modelfile_path)],
        capture_output=True, text=True, timeout=300,
    )
    if result.returncode == 0:
        log("info", msg=f"Model '{model_name}' registered with Ollama")
    else:
        log("warning", msg=f"Ollama registration failed: {result.stderr}")
if __name__ == "__main__":
    import signal
    import traceback

    def _sigterm_to_interrupt(signum, frame):
        """Convert SIGTERM into KeyboardInterrupt so cancellation is recorded."""
        # Restore the default first, so a second SIGTERM during cleanup still
        # terminates the process.
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        raise KeyboardInterrupt

    # By default Python turns only SIGINT into KeyboardInterrupt. Without this
    # handler, a SIGTERM (e.g. Popen.terminate() from the launching backend,
    # presumably how jobs are cancelled) ends the process immediately. The
    # job would then never be logged as interrupted or marked "cancelled".
    signal.signal(signal.SIGTERM, _sigterm_to_interrupt)

    try:
        main()
    except KeyboardInterrupt:
        log("error", msg="Job interrupted (SIGTERM/SIGINT)")
        set_status("cancelled")
        sys.exit(130)
    except Exception as e:
        log("error", msg=str(e), traceback=traceback.format_exc())
        set_status("failed")
        sys.exit(1)
    finally:
        log_file.close()