310 lines
12 KiB
Python
310 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
Nexus One AI — QLoRA Fine-Tuning Runner

Launched as a subprocess by the FastAPI backend.

Writes structured JSONL log lines to --log-path so the UI can stream
live loss curves and progress. Updates training_jobs.status in SQLite.

Requires (install on the training node):
    pip install torch transformers datasets peft bitsandbytes trl

Optional (faster, lower VRAM):
    pip install unsloth

Usage (called by main.py — do not run manually in production):
    python3 train_qlora.py --job-id 1 --db-path /opt/cezen/data/cezen.db \
        --dataset /opt/cezen/data/datasets/abc.jsonl \
        --base-model mistral:7b --output-dir /opt/cezen/data/finetuned/mymodel \
        --log-path /opt/cezen/data/job_logs/abc.jsonl \
        --epochs 3 --lr 2e-4 --batch-size 4 --lora-r 16 --lora-alpha 32 \
        --output-name mymodel
"""
|
|
|
|
import argparse, json, os, sqlite3, sys, time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
# ── Argument parsing ──────────────────────────────────────────────────────────
|
|
|
|
def _build_parser():
    """Build the command-line parser for the flags the backend passes in.

    Path/identifier flags are mandatory; the hyper-parameter flags fall back
    to the defaults listed below when omitted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--job-id", type=int, required=True)
    # Mandatory string-valued flags, registered in the order --help lists them.
    for flag in ("--db-path", "--dataset", "--base-model",
                 "--output-dir", "--log-path", "--output-name"):
        cli.add_argument(flag, required=True)
    # Optional training hyper-parameters as (flag, type, default).
    for flag, cast, default in (
        ("--epochs", int, 3),
        ("--lr", float, 2e-4),
        ("--batch-size", int, 4),
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
    ):
        cli.add_argument(flag, type=cast, default=default)
    return cli


parser = _build_parser()
args = parser.parse_args()
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def utcnow():
    """Return the current time in UTC as an ISO-8601 string (``+00:00`` offset)."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat()
|
|
|
|
def db_connect():
    """Open the SQLite job database given by ``--db-path``.

    Rows come back as ``sqlite3.Row`` so columns can be read by name. The
    caller owns the returned connection and is responsible for closing it.
    """
    connection = sqlite3.connect(args.db_path)
    connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
# Create the log directory up front: if this open() failed, the script would
# crash at import time, before it could write a record or mark the job failed.
Path(args.log_path).parent.mkdir(parents=True, exist_ok=True)
# Line-buffered (buffering=1): each JSONL record is flushed as soon as its
# newline is written, so the UI can stream the file live.
log_file = open(args.log_path, "a", buffering=1, encoding="utf-8")


def log(type_: str, **kwargs):
    """Append one JSON record to the job log at ``--log-path``.

    Each record is a single line ``{"ts": <UTC ISO-8601>, "type": type_, **kwargs}``.
    This script uses the record types "start", "info", "warning", "error",
    "loss" and "complete". Keyword values must be JSON-serialisable.
    """
    entry = {"ts": utcnow(), "type": type_, **kwargs}
    log_file.write(json.dumps(entry) + "\n")
|
|
|
|
def set_status(status: str):
    """Write *status* to this job's row in ``training_jobs``.

    Terminal statuses ("completed", "failed", "cancelled") also stamp
    ``finished_at`` with the current UTC time. The update is committed on
    success and rolled back if it fails. The connection is closed in both
    cases; previously it leaked whenever ``execute``/``commit`` raised.
    """
    from contextlib import closing

    with closing(db_connect()) as db:
        # sqlite3.Connection as a context manager commits on success and rolls
        # back on exception; it does not close, which closing() handles.
        with db:
            if status in ("completed", "failed", "cancelled"):
                db.execute(
                    "UPDATE training_jobs SET status=?, finished_at=? WHERE id=?",
                    (status, utcnow(), args.job_id),
                )
            else:
                db.execute(
                    "UPDATE training_jobs SET status=? WHERE id=?",
                    (status, args.job_id),
                )
|
|
|
|
# ── Dataset loading ───────────────────────────────────────────────────────────
|
|
|
|
def load_dataset_from_file(path: str):
    """Load a training dataset from a CSV or JSONL file into a list of dicts.

    * ``.csv`` (suffix matched case-insensitively): read with ``csv.DictReader``,
      so each row becomes a dict keyed by the header line.
    * anything else: treated as JSONL, one JSON object per non-blank line.
      Lines that are not valid JSON, or that decode to something other than an
      object (a bare list, string or number), are skipped.

    Both formats are decoded as UTF-8 with an optional leading byte-order mark
    removed (``utf-8-sig``). Undecodable bytes become U+FFFD. Without this, a
    BOM would be glued onto the first CSV header ("\\ufefftext") and would make
    the first JSONL line fail to parse, silently dropping it.

    Args:
        path: Path to the dataset file.

    Returns:
        A list of dicts, typically with a ``text`` key or ``prompt``/``completion``
        (or ``instruction``/``output``) keys; see ``format_row``.
    """
    p = Path(path)

    if p.suffix.lower() == ".csv":
        import csv

        with open(p, newline="", encoding="utf-8-sig", errors="replace") as f:
            return [dict(row) for row in csv.DictReader(f)]

    rows = []
    with open(p, encoding="utf-8-sig", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Malformed line: skip it rather than abort the whole dataset.
                continue
            # Only JSON objects can be formatted into training text downstream.
            if isinstance(record, dict):
                rows.append(record)
    return rows
|
|
|
|
def format_row(row: dict) -> str:
    """Render one dataset row as the plain-text string the model is trained on.

    Recognised layouts, checked in this order:

    * ``text``: the value is returned unchanged.
    * ``prompt`` + ``completion``: Instruction/Response template.
    * ``instruction`` + ``output``: Instruction/[Input/]Response template. The
      Input section appears only when ``input`` is present and truthy.
    * anything else: every truthy value, stringified, joined by single spaces.
    """
    if "text" in row:
        return row["text"]

    sections = None
    if "prompt" in row and "completion" in row:
        sections = [
            f"### Instruction:\n{row['prompt']}",
            f"### Response:\n{row['completion']}",
        ]
    elif "instruction" in row and "output" in row:
        sections = [f"### Instruction:\n{row['instruction']}"]
        extra = row.get("input", "")
        if extra:
            sections.append(f"### Input:\n{extra}")
        sections.append(f"### Response:\n{row['output']}")

    if sections is not None:
        # Template sections are separated by one blank line.
        return "\n\n".join(sections)

    # Unknown schema: concatenate whatever non-empty values the row has.
    return " ".join(map(str, filter(None, row.values())))
|
|
|
|
# ── Main training routine ─────────────────────────────────────────────────────
|
|
|
|
# Token budget per training example; also the context length Unsloth is loaded with.
_MAX_SEQ_LENGTH = 2048

# Ollama model families mapped to the Hugging Face repos that are actually
# downloaded and fine-tuned. Only the part of an Ollama tag before ":" is matched.
_OLLAMA_TO_HF = {
    "mistral": "mistralai/Mistral-7B-v0.1",
    "llama2": "meta-llama/Llama-2-7b-hf",
    "llama3": "meta-llama/Meta-Llama-3-8B",
    "phi3": "microsoft/Phi-3-mini-4k-instruct",
    "gemma": "google/gemma-7b",
    "codellama": "codellama/CodeLlama-7b-hf",
    "qwen2": "Qwen/Qwen2-7B",
}


def _resolve_hf_model(base_model: str) -> str:
    """Return the Hugging Face repo id (or path) to train from for *base_model*.

    Ollama-style tags such as ``mistral:7b`` (a ``:`` and no ``/``) are looked
    up by family in ``_OLLAMA_TO_HF``. Anything else is assumed to already be
    a Hugging Face repo id or local path and is returned unchanged.
    """
    if ":" not in base_model or "/" in base_model:
        return base_model
    hf_model = _OLLAMA_TO_HF.get(base_model.split(":")[0].lower())
    if hf_model is None:
        log("warning", msg=f"No HuggingFace mapping for '{base_model}' — trying it as-is")
        return base_model
    log("info", msg=f"Mapped '{base_model}' → '{hf_model}' (HuggingFace)")
    return hf_model


def _load_model_and_tokenizer(hf_model, fast_language_model, use_bf16):
    """Load *hf_model* quantised to 4 bits and wrap it with LoRA adapters.

    Args:
        hf_model: Hugging Face repo id or local path.
        fast_language_model: Unsloth's ``FastLanguageModel`` class, or None to
            use transformers + PEFT + bitsandbytes.
        use_bf16: Whether the GPU supports bfloat16; selects the 4-bit compute
            dtype on the PEFT path so it matches the mixed-precision mode.

    Returns:
        ``(model, tokenizer)``, with ``tokenizer.pad_token`` guaranteed set.
    """
    import torch

    if fast_language_model is not None:
        model, tokenizer = fast_language_model.from_pretrained(
            model_name=hf_model,
            max_seq_length=_MAX_SEQ_LENGTH,
            dtype=None,
            load_in_4bit=True,
        )
        model = fast_language_model.get_peft_model(
            model,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        )
        tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            hf_model,
            quantization_config=bnb_cfg,
            device_map="auto",
            trust_remote_code=True,
        )
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            # NOTE(review): Llama/Mistral-style attention projection names.
            # Architectures that fuse them (e.g. Phi-3's qkv_proj) may only
            # partially match; confirm for every base model offered.
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        ))

    # Batch padding needs a pad token; many causal-LM tokenizers ship without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _train(hf_model, texts, fast_language_model, output_dir):
    """Fine-tune *hf_model* on *texts* with QLoRA and save the result to *output_dir*.

    Args:
        hf_model: Hugging Face repo id or local path of the base model.
        texts: Non-empty list of formatted training strings.
        fast_language_model: Unsloth's ``FastLanguageModel`` class, or None.
        output_dir: Receives per-epoch checkpoints, the final adapter weights
            and the tokenizer files.

    Raises:
        RuntimeError: if no CUDA device is available.
        Any exception raised while loading the model or training propagates.
    """
    import torch
    from datasets import Dataset as HFDataset
    from transformers import TrainerCallback, TrainingArguments
    from trl import SFTTrainer

    if not torch.cuda.is_available():
        raise RuntimeError(
            "QLoRA training requires a CUDA GPU, but torch.cuda.is_available() is False"
        )
    # Queried once; drives both the 4-bit compute dtype and the mixed-precision flags.
    use_bf16 = torch.cuda.is_bf16_supported()

    model, tokenizer = _load_model_and_tokenizer(hf_model, fast_language_model, use_bf16)

    class LossLogger(TrainerCallback):
        """Stream each Trainer loss report into the JSONL job log."""

        def on_log(self, _args, state, control, logs=None, **kwargs):
            # `_args` (the TrainingArguments) is not named `args`, so the
            # module-level CLI namespace used for the lr fallback stays visible.
            if logs and "loss" in logs:
                log("loss",
                    step=state.global_step,
                    loss=round(float(logs["loss"]), 6),
                    epoch=round(float(logs.get("epoch", 0)), 3),
                    lr=float(logs.get("learning_rate", args.lr)))

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        # Hand SFTTrainer raw strings: `dataset_text_field` must name a text
        # column (it was "input_ids" on a pre-tokenised dataset). SFTTrainer
        # tokenises and truncates to `max_seq_length` itself, and batches are
        # padded by the collator rather than every row being padded to 2048.
        train_dataset=HFDataset.from_dict({"text": texts}),
        dataset_text_field="text",
        max_seq_length=_MAX_SEQ_LENGTH,
        args=TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            learning_rate=args.lr,
            fp16=not use_bf16,
            bf16=use_bf16,
            logging_steps=1,
            save_strategy="epoch",
            report_to="none",
        ),
        callbacks=[LossLogger()],
    )

    log("info", msg="Training started")
    trainer.train()
    log("info", msg="Training complete — saving model")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)


def main():
    """Run one fine-tuning job end to end and record its outcome.

    Steps:
    1. Mark the job ``running`` and resolve the base model.
    2. Load and format the dataset.
    3. Fine-tune with QLoRA and save the adapter to ``--output-dir``.
    4. Try to register the adapter with Ollama; failure there is only a warning.
    5. Mark the job ``completed``.

    If the dataset yields no usable text, or training raises, it logs an
    ``error`` record, marks the job ``failed`` and exits with status 1.
    """
    log("start", job_id=args.job_id, base_model=args.base_model,
        epochs=args.epochs, lr=args.lr, batch_size=args.batch_size,
        lora_r=args.lora_r, lora_alpha=args.lora_alpha)
    set_status("running")

    hf_model = _resolve_hf_model(args.base_model)

    log("info", msg="Loading dataset...")
    raw_rows = load_dataset_from_file(args.dataset)
    # Drop rows that format to nothing usable (e.g. an empty or non-string
    # "text" value) so the trainer never sees empty sequences.
    texts = [t for t in map(format_row, raw_rows) if isinstance(t, str) and t.strip()]
    if not texts:
        log("error", msg="Dataset is empty or could not be parsed")
        set_status("failed")
        sys.exit(1)
    if len(texts) < len(raw_rows):
        log("warning", msg=f"Skipped {len(raw_rows) - len(texts)} rows with no usable text")
    log("info", msg=f"Loaded {len(texts)} training examples")

    # Try Unsloth first (faster, lower VRAM), falling back to HF PEFT.
    # Imported before transformers/trl because Unsloth patches them on import.
    try:
        from unsloth import FastLanguageModel
    except ImportError:
        FastLanguageModel = None
        log("info", msg="Unsloth not available — using HuggingFace PEFT + BitsAndBytes")
    else:
        log("info", msg="Using Unsloth for accelerated training")

    output_dir = args.output_dir
    try:
        _train(hf_model, texts, FastLanguageModel, output_dir)
    except Exception as e:
        import traceback
        log("error", msg=str(e), traceback=traceback.format_exc())
        set_status("failed")
        sys.exit(1)

    # Auto-register with Ollama via Modelfile. Not fatal: the adapter is already saved.
    try:
        _register_with_ollama(output_dir, args.output_name)
    except Exception as e:
        log("warning", msg=f"Could not auto-register with Ollama: {e}")

    log("complete", msg="Job finished successfully", output_dir=output_dir)
    set_status("completed")
|
|
|
|
|
|
def _register_with_ollama(model_dir: str, model_name: str, base_model=None):
    """Register the LoRA adapter saved in *model_dir* with Ollama as *model_name*.

    Writes ``<model_dir>/Modelfile`` and runs ``ollama create`` with a
    300-second timeout.

    *model_dir* holds ``trainer.save_model`` output for a PEFT/LoRA-wrapped
    model, i.e. adapter weights rather than a full model. The Modelfile
    therefore names the base model in ``FROM`` and applies the adapter with
    ``ADAPTER``; Ollama's ``FROM <directory>`` expects full model weights.

    Args:
        model_dir: Directory containing the saved adapter and tokenizer.
        model_name: Name to create the model under in Ollama.
        base_model: Ollama model the adapter is applied to. Defaults to
            ``--base-model``. NOTE(review): Ollama documents erratic output if
            this is not the same model the adapter was trained from. Confirm the
            Ollama tag and the Hugging Face repo used for training are the same
            weights.

    A non-zero exit from ``ollama create`` is logged as a warning, not raised.

    Raises:
        FileNotFoundError: if the ``ollama`` executable cannot be found.
        subprocess.TimeoutExpired: if ``ollama create`` runs longer than 300 s.
    """
    import subprocess

    if base_model is None:
        base_model = args.base_model

    # Absolute path: Ollama resolves a relative ADAPTER path against the
    # Modelfile's own directory, which here is model_dir itself.
    adapter_dir = Path(model_dir).resolve()
    modelfile_path = adapter_dir / "Modelfile"
    # NOTE(review): the ChatML stop token is kept from the previous Modelfile;
    # training prompts use "### Instruction:/### Response:", so confirm it fits.
    modelfile_path.write_text(
        f"FROM {base_model}\n"
        f"ADAPTER {adapter_dir}\n"
        'PARAMETER stop "<|im_end|>"\n'
        'SYSTEM "This is a Nexus One AI fine-tuned model."\n',
        encoding="utf-8",
    )

    result = subprocess.run(
        ["ollama", "create", model_name, "-f", str(modelfile_path)],
        capture_output=True, text=True, timeout=300,
    )
    if result.returncode == 0:
        log("info", msg=f"Model '{model_name}' registered with Ollama")
    else:
        log("warning", msg=f"Ollama registration failed: {result.stderr}")
|
|
|
|
|
|
if __name__ == "__main__":
    import signal

    def _on_sigterm(signum, frame):
        """Convert SIGTERM into KeyboardInterrupt so the cancellation path runs."""
        raise KeyboardInterrupt

    # Python maps only SIGINT to KeyboardInterrupt by default. An unhandled
    # SIGTERM (presumably how the backend cancels a job) would end the process
    # without logging or updating the DB, leaving the job stuck in "running".
    signal.signal(signal.SIGTERM, _on_sigterm)

    try:
        main()
    except KeyboardInterrupt:
        # Ignore repeated SIGTERMs so they cannot interrupt the status update.
        signal.signal(signal.SIGTERM, signal.SIG_IGN)
        log("error", msg="Job interrupted (SIGTERM/SIGINT)")
        set_status("cancelled")
        sys.exit(130)
    except Exception as e:
        import traceback
        log("error", msg=str(e), traceback=traceback.format_exc())
        set_status("failed")
        sys.exit(1)
    finally:
        log_file.close()
|