Mom-Portal/scripts/whisper_http_server.py
KevinB-T 9517bad3dc feat: add WhisperX diarization and speaker transcript UI
- add WhisperX diarization support to the Whisper VM server
- normalize speaker timestamp segments from Whisper responses
- document Hugging Face/pyannote VM setup and health checks
- show diarized speaker transcript blocks in record and transcript views
- group consecutive segments from the same speaker
- remove duplicate paragraph transcript display when diarized segments exist
- let diarized transcript content expand without an inner scrollbar
2026-05-20 16:34:50 +05:30

421 lines
14 KiB
Python

#!/usr/bin/env python3
"""Tiny Faster-Whisper/WhisperX HTTP API for Orphion.
Endpoints:
GET /health
POST /transcribe multipart form field "file"
"""
from __future__ import annotations
import argparse
import cgi
import gc
import inspect
import json
import os
import tempfile
import threading
import traceback
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from faster_whisper import WhisperModel
try:
import torch
except Exception: # pragma: no cover - torch is optional for CPU fallback
torch = None
try:
import whisperx
except Exception: # pragma: no cover - keeps existing Faster-Whisper path available
whisperx = None
try:
import whisperx.diarize as whisperx_diarize
except Exception: # pragma: no cover - older/newer WhisperX layouts vary
whisperx_diarize = None
try:
import ctranslate2
except Exception: # pragma: no cover - only used for runtime device selection
ctranslate2 = None
# Process-wide lazily-initialised state shared by all request threads.
# Guards creation of the Faster-Whisper model held in MODEL (see get_model).
MODEL_LOCK = threading.Lock()
# Cached Faster-Whisper model; None until get_model() first builds it.
MODEL = None
# Guards creation of WHISPERX_MODEL and DIARIZATION_PIPELINE.
WHISPERX_LOCK = threading.Lock()
# Held by the request handler for the duration of a transcription so only one
# inference runs at a time.
INFERENCE_LOCK = threading.Lock()
# Cached WhisperX model; None until get_whisperx_model() first builds it.
WHISPERX_MODEL = None
# Cache of whisperx.load_align_model() results keyed by (language, device).
ALIGN_MODELS = {}
# Cached diarization pipeline; None until get_diarization_pipeline() builds it.
DIARIZATION_PIPELINE = None
def choose_device() -> str:
    """Pick the device used for Faster-Whisper inference.

    A non-blank ``WHISPER_DEVICE`` environment variable always wins. Otherwise
    the result is ``"cuda"`` when ctranslate2 was importable and reports at
    least one CUDA device, and ``"cpu"`` in every other case (including when
    the CUDA probe itself raises).
    """
    override = os.getenv("WHISPER_DEVICE", "").strip()
    if override:
        return override
    if ctranslate2 is None:
        return "cpu"
    try:
        has_cuda = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # A failing probe is treated as "no GPU" instead of an error.
        has_cuda = False
    return "cuda" if has_cuda else "cpu"
def default_compute_type(device: str) -> str:
    """Return the compute type to use for Faster-Whisper on *device*.

    A non-blank ``WHISPER_COMPUTE_TYPE`` environment variable takes
    precedence; otherwise devices whose name starts with ``"cuda"`` get
    ``"float16"`` and everything else gets ``"int8"``.
    """
    override = os.getenv("WHISPER_COMPUTE_TYPE", "").strip()
    if override:
        return override
    if device.startswith("cuda"):
        return "float16"
    return "int8"
def diarization_enabled() -> bool:
    """Report whether ``WHISPERX_DIARIZATION`` is switched on.

    The variable defaults to ``"true"``; it is compared case-insensitively
    (without stripping whitespace) against ``1``, ``true``, ``yes`` and ``on``.
    """
    flag = os.getenv("WHISPERX_DIARIZATION", "true")
    return flag.lower() in ("1", "true", "yes", "on")
def whisperx_device() -> str:
    """Return the device WhisperX should run on.

    A non-blank ``WHISPERX_DEVICE`` environment variable is used verbatim
    (after stripping); otherwise the choice is delegated to
    ``choose_device()``.
    """
    override = os.getenv("WHISPERX_DEVICE", "").strip()
    return override or choose_device()
def whisperx_compute_type(device: str) -> str:
    """Return the compute type to use for WhisperX on *device*.

    Precedence: a non-blank ``WHISPERX_COMPUTE_TYPE`` environment variable,
    then ``"float16"`` for devices whose name starts with ``"cuda"``, then
    whatever ``default_compute_type(device)`` yields.
    """
    override = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip()
    if override:
        return override
    if device.startswith("cuda"):
        return "float16"
    return default_compute_type(device)
def huggingface_token() -> str | None:
    """Return the first non-empty Hugging Face token found in the environment.

    ``HUGGINGFACE_TOKEN``, ``HF_TOKEN`` and ``PYANNOTE_AUTH_TOKEN`` are
    consulted in that order; ``None`` is returned when none of them is set to
    a non-empty value.
    """
    for variable in ("HUGGINGFACE_TOKEN", "HF_TOKEN", "PYANNOTE_AUTH_TOKEN"):
        value = os.getenv(variable)
        if value:
            return value
    return None
def get_model() -> WhisperModel:
    """Return the shared Faster-Whisper model, building it on first use.

    The model name comes from ``WHISPER_MODEL`` (default ``"large-v3"``), the
    device from ``choose_device()``, the compute type from
    ``default_compute_type()`` and the download directory from
    ``WHISPER_MODEL_DIR`` (unset or empty means ``None``).
    """
    global MODEL
    if MODEL is None:
        with MODEL_LOCK:
            # Re-check under the lock: another thread may have built the
            # model while this one was waiting.
            if MODEL is None:
                name = os.getenv("WHISPER_MODEL", "large-v3")
                target = choose_device()
                MODEL = WhisperModel(
                    name,
                    device=target,
                    compute_type=default_compute_type(target),
                    download_root=os.getenv("WHISPER_MODEL_DIR") or None,
                )
    return MODEL
def get_whisperx_model():
    """Return the shared WhisperX model, loading it on first use.

    The model name is read from ``WHISPERX_MODEL``, falling back to
    ``WHISPER_MODEL`` and then ``"large-v3"``.

    Raises:
        RuntimeError: if no model is cached yet and whisperx could not be
            imported.
    """
    global WHISPERX_MODEL
    if WHISPERX_MODEL is not None:
        return WHISPERX_MODEL
    if whisperx is None:
        raise RuntimeError("whisperx is not installed")
    with WHISPERX_LOCK:
        # Another thread may have finished loading while we waited.
        if WHISPERX_MODEL is not None:
            return WHISPERX_MODEL
        target = whisperx_device()
        name = os.getenv("WHISPERX_MODEL", os.getenv("WHISPER_MODEL", "large-v3"))
        compute_type = whisperx_compute_type(target)
        download_root = os.getenv("WHISPER_MODEL_DIR") or None
        try:
            WHISPERX_MODEL = whisperx.load_model(
                name,
                target,
                compute_type=compute_type,
                download_root=download_root,
            )
        except TypeError:
            # Retry without download_root for load_model() signatures that
            # reject that keyword.
            WHISPERX_MODEL = whisperx.load_model(name, target, compute_type=compute_type)
        return WHISPERX_MODEL
def get_align_model(language_code: str | None, device: str):
    """Return the cached result of ``whisperx.load_align_model`` for a language.

    When *language_code* is falsy the language falls back to
    ``WHISPERX_ALIGN_LANGUAGE`` (default ``"en"``). Results are cached in
    ``ALIGN_MODELS`` under the ``(language, device)`` pair.

    Raises:
        RuntimeError: if whisperx could not be imported.
    """
    if whisperx is None:
        raise RuntimeError("whisperx is not installed")
    language = language_code or os.getenv("WHISPERX_ALIGN_LANGUAGE", "en")
    cache_key = (language, device)
    if cache_key in ALIGN_MODELS:
        return ALIGN_MODELS[cache_key]
    loaded = whisperx.load_align_model(language_code=language, device=device)
    ALIGN_MODELS[cache_key] = loaded
    return loaded
def get_diarization_pipeline(device: str):
    """Return the shared WhisperX diarization pipeline, building it on first use.

    The pipeline class is looked up as ``whisperx.DiarizationPipeline`` and,
    failing that, ``whisperx.diarize.DiarizationPipeline``. Its signature is
    inspected so the Hugging Face token is passed under whichever keyword the
    class accepts (``use_auth_token``, ``auth_token`` or ``token``).

    ``WHISPERX_DIARIZATION_MODEL`` is forwarded as ``model_name`` only when it
    is set to a non-blank value and the class accepts that keyword. This makes
    the override work for every token keyword and avoids passing an explicit
    ``model_name=None`` when the variable is unset.

    Args:
        device: Device name handed to the pipeline constructor.

    Raises:
        RuntimeError: if whisperx is not importable, no Hugging Face token is
            configured, or no ``DiarizationPipeline`` class can be found.
    """
    global DIARIZATION_PIPELINE
    if whisperx is None:
        raise RuntimeError("whisperx is not installed")
    if DIARIZATION_PIPELINE is not None:
        return DIARIZATION_PIPELINE
    token = huggingface_token()
    if not token:
        raise RuntimeError("HuggingFace token is required for speaker diarization")
    with WHISPERX_LOCK:
        # Re-check under the lock: another thread may have built the pipeline.
        if DIARIZATION_PIPELINE is None:
            pipeline_factory = getattr(whisperx, "DiarizationPipeline", None)
            if pipeline_factory is None and whisperx_diarize is not None:
                pipeline_factory = getattr(whisperx_diarize, "DiarizationPipeline", None)
            if pipeline_factory is None:
                raise RuntimeError("WhisperX diarization pipeline is not available")

            parameters = inspect.signature(pipeline_factory).parameters
            kwargs = {"device": device}
            # Pass the token under the first keyword this class understands;
            # when none matches, the pipeline is built without a token.
            for token_keyword in ("use_auth_token", "auth_token", "token"):
                if token_keyword in parameters:
                    kwargs[token_keyword] = token
                    break
            model_name = os.getenv("WHISPERX_DIARIZATION_MODEL", "").strip()
            if model_name and "model_name" in parameters:
                kwargs["model_name"] = model_name
            DIARIZATION_PIPELINE = pipeline_factory(**kwargs)
    return DIARIZATION_PIPELINE
def cleanup_gpu_memory() -> None:
    """Run a garbage collection pass and, best-effort, empty the CUDA cache.

    The CUDA step is skipped when torch was not importable or has no ``cuda``
    attribute; any exception raised while probing or emptying the cache is
    deliberately ignored.
    """
    gc.collect()
    cuda = getattr(torch, "cuda", None) if torch is not None else None
    if cuda is None:
        return
    try:
        if cuda.is_available():
            cuda.empty_cache()
    except Exception:
        # Cleanup is opportunistic; never let it fail a request.
        pass
def normalize_speaker(raw: str | None, mapping: dict[str, str]) -> str:
    """Translate a raw speaker label into a stable ``"Speaker N"`` name.

    A falsy *raw* yields ``"Speaker 1"`` without touching *mapping*.
    Otherwise the label (as a string) is looked up in *mapping*; unseen labels
    are registered as ``"Speaker <len(mapping) + 1>"``, so numbering follows
    first appearance. *mapping* is mutated in place.
    """
    if not raw:
        return "Speaker 1"
    label = str(raw)
    return mapping.setdefault(label, f"Speaker {len(mapping) + 1}")
def segment_speaker(segment: dict) -> str | None:
    """Determine the speaker label for a transcript segment.

    A truthy ``segment["speaker"]`` is returned as-is. Otherwise the labels on
    the segment's ``"words"`` entries are tallied and the most frequent one is
    returned; on a tie the label that appeared first wins. ``None`` is
    returned when no word carries a speaker.
    """
    direct = segment.get("speaker")
    if direct:
        return direct
    tally: dict[str, int] = {}
    word_labels = (word.get("speaker") for word in segment.get("words") or [])
    for label in word_labels:
        if label:
            tally[label] = tally.get(label, 0) + 1
    if not tally:
        return None
    # max() keeps the first maximal key, and dicts preserve insertion order.
    return max(tally, key=tally.get)
def normalize_segments(segments: list[dict], include_speakers: bool = True) -> list[dict]:
    """Convert raw segments into the response's timestamp entries.

    Segments whose stripped text is empty are dropped. Each remaining entry
    carries ``start`` and ``end`` rounded to three decimals (missing or falsy
    values become ``0.0``) and the stripped ``text``. When *include_speakers*
    is true a ``speaker`` key is added, numbered in order of first appearance
    across the whole list.
    """
    labels_seen: dict[str, str] = {}
    entries = []
    for raw_segment in segments:
        cleaned = str(raw_segment.get("text") or "").strip()
        if cleaned:
            entry = {
                "start": round(float(raw_segment.get("start") or 0), 3),
                "end": round(float(raw_segment.get("end") or 0), 3),
                "text": cleaned,
            }
            if include_speakers:
                entry["speaker"] = normalize_speaker(segment_speaker(raw_segment), labels_seen)
            entries.append(entry)
    return entries
def transcribe_with_whisperx(audio_path: str) -> dict:
    """Transcribe *audio_path* with WhisperX, then align and diarize it.

    Alignment and diarization are each best-effort: a failure in either is
    printed via ``traceback`` and processing continues with the result
    obtained so far. When diarization fails the returned segments carry no
    ``speaker`` key and ``"diarization"`` is ``"fallback"``; otherwise it is
    ``"completed"``.

    Returns:
        A dict with ``transcript_text``, ``language``, ``duration``,
        ``timestamps`` and ``diarization``.
    """
    device = whisperx_device()
    model = get_whisperx_model()
    audio = whisperx.load_audio(audio_path)
    batch_size = int(os.getenv("WHISPERX_BATCH_SIZE", os.getenv("WHISPER_BATCH_SIZE", "8")))
    result = model.transcribe(audio, batch_size=batch_size)
    language = result.get("language")

    # Stage 1: alignment. On failure keep the unaligned transcription.
    try:
        align_model, metadata = get_align_model(language, device)
        result = whisperx.align(
            result.get("segments", []),
            align_model,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )
    except Exception:
        traceback.print_exc()

    # Stage 2: speaker diarization. On failure return segments without speakers.
    diarization_failed = False
    try:
        diarize_model = get_diarization_pipeline(device)
        speaker_bounds = {}
        for option, variable in (
            ("min_speakers", "WHISPERX_MIN_SPEAKERS"),
            ("max_speakers", "WHISPERX_MAX_SPEAKERS"),
        ):
            configured = os.getenv(variable)
            if configured:
                speaker_bounds[option] = int(configured)
        diarize_segments = diarize_model(audio, **speaker_bounds)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    except Exception:
        diarization_failed = True
        traceback.print_exc()

    timestamps = normalize_segments(
        result.get("segments", []),
        include_speakers=not diarization_failed,
    )
    if hasattr(audio, "shape"):
        # NOTE(review): assumes audio is sampled at 16 kHz — confirm against
        # whisperx.load_audio.
        duration = float(audio.shape[0]) / 16000
    else:
        duration = None
    return {
        "transcript_text": " ".join(item["text"] for item in timestamps if item.get("text")),
        "language": language,
        "duration": duration,
        "timestamps": timestamps,
        "diarization": "fallback" if diarization_failed else "completed",
    }
def transcribe_with_faster_whisper(audio_path: str) -> dict:
    """Transcribe *audio_path* with the plain Faster-Whisper model.

    ``WHISPER_BEAM_SIZE`` (default ``5``) and ``WHISPER_VAD_FILTER`` (default
    on) tune the call. No speaker information is produced, so
    ``"diarization"`` is always ``"disabled"``.

    Returns:
        A dict with ``transcript_text``, ``language``, ``duration``,
        ``timestamps`` and ``diarization``.
    """
    model = get_model()
    beam_size = int(os.getenv("WHISPER_BEAM_SIZE", "5"))
    vad_filter = os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"}
    segments, info = model.transcribe(audio_path, beam_size=beam_size, vad_filter=vad_filter)

    # Every segment is kept in timestamps, even when its stripped text is
    # empty; only the joined transcript skips empty pieces.
    timestamps = [
        {"start": piece.start, "end": piece.end, "text": piece.text.strip()}
        for piece in segments
    ]
    return {
        "transcript_text": " ".join(entry["text"] for entry in timestamps if entry["text"]),
        "language": getattr(info, "language", None),
        "duration": getattr(info, "duration", None),
        "timestamps": timestamps,
        "diarization": "disabled",
    }
class WhisperHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing ``GET /health`` and ``POST /transcribe``."""

    server_version = "OrphionWhisper/1.0"

    def do_GET(self) -> None:
        """Answer ``/health`` with a JSON status report; 404 for other paths.

        ``diarization_ready`` is true only when whisperx was importable,
        diarization is enabled and a Hugging Face token is configured.
        """
        if self.path.rstrip("/") != "/health":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        self.send_json(
            {
                "status": "ok",
                "model": os.getenv("WHISPER_MODEL", "large-v3"),
                "device": choose_device(),
                "whisperx": whisperx is not None,
                "diarization": diarization_enabled(),
                "diarization_ready": whisperx is not None
                and diarization_enabled()
                and bool(huggingface_token()),
            },
        )

    def do_POST(self) -> None:
        """Dispatch ``/transcribe``; 404 for other paths, 500 on any failure."""
        if self.path.rstrip("/") != "/transcribe":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        try:
            self.handle_transcribe()
        except Exception as exc:
            traceback.print_exc()
            self.send_json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def handle_transcribe(self) -> None:
        """Save the uploaded audio to a temp file, transcribe it and reply.

        The upload is read from the multipart field named by
        ``WHISPER_FILE_FIELD`` (default ``"file"``). With diarization enabled
        the WhisperX path is tried first and, if it raises, Faster-Whisper is
        used with ``"diarization": "fallback"``. With diarization disabled
        Faster-Whisper runs once and its errors propagate to the caller.

        The temporary file is removed and GPU memory cleanup is attempted in
        all cases, including when copying the upload fails part-way.
        """
        content_type = self.headers.get("content-type", "")
        if not content_type.startswith("multipart/form-data"):
            self.send_json({"error": "multipart/form-data required"}, HTTPStatus.BAD_REQUEST)
            return
        # NOTE(review): the cgi module is deprecated since Python 3.11 and
        # removed in 3.13; this parsing needs a replacement before upgrading.
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={
                "REQUEST_METHOD": "POST",
                "CONTENT_TYPE": content_type,
                "CONTENT_LENGTH": self.headers.get("content-length", "0"),
            },
        )
        field_name = os.getenv("WHISPER_FILE_FIELD", "file")
        upload = form[field_name] if field_name in form else None
        if upload is None or not getattr(upload, "file", None):
            self.send_json({"error": f"missing multipart field '{field_name}'"}, HTTPStatus.BAD_REQUEST)
            return
        # Keep the original extension on the temp file name.
        suffix = Path(getattr(upload, "filename", "") or "audio").suffix
        temp_path = None
        try:
            # The copy happens inside the try so a failed or partial upload
            # cannot leave the delete=False temp file behind.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                while True:
                    chunk = upload.file.read(1024 * 1024)
                    if not chunk:
                        break
                    temp_file.write(chunk)
            # Only one inference runs at a time across request threads.
            with INFERENCE_LOCK:
                if diarization_enabled():
                    try:
                        payload = transcribe_with_whisperx(temp_path)
                    except Exception:
                        traceback.print_exc()
                        payload = transcribe_with_faster_whisper(temp_path)
                        payload["diarization"] = "fallback"
                else:
                    # No fallback exists for this path, so it is not retried.
                    payload = transcribe_with_faster_whisper(temp_path)
            self.send_json(payload)
        finally:
            cleanup_gpu_memory()
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def log_message(self, fmt: str, *args: object) -> None:
        """Print request log lines to stdout, flushed immediately."""
        print(f"{self.address_string()} - {fmt % args}", flush=True)

    def send_json(self, payload: dict, status: HTTPStatus = HTTPStatus.OK) -> None:
        """Serialise *payload* as UTF-8 JSON and send it with *status*."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def main() -> None:
    """Parse ``--host``/``--port`` and serve the Whisper API until interrupted.

    Defaults come from ``WHISPER_HOST`` (``0.0.0.0``) and ``WHISPER_PORT``
    (``8000``).
    """
    default_host = os.getenv("WHISPER_HOST", "0.0.0.0")
    default_port = int(os.getenv("WHISPER_PORT", "8000"))

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default=default_host)
    parser.add_argument("--port", default=default_port, type=int)
    options = parser.parse_args()

    httpd = ThreadingHTTPServer((options.host, options.port), WhisperHandler)
    print(f"Whisper API listening on http://{options.host}:{options.port}", flush=True)
    httpd.serve_forever()


if __name__ == "__main__":
    main()