- add WhisperX diarization support to the Whisper VM server - normalize speaker timestamp segments from Whisper responses - document Hugging Face/pyannote VM setup and health checks - show diarized speaker transcript blocks in record and transcript views - group consecutive segments from the same speaker - remove duplicate paragraph transcript display when diarized segments exist - let diarized transcript content expand without an inner scrollbar
421 lines
14 KiB
Python
421 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""Tiny Faster-Whisper/WhisperX HTTP API for Orphion.
|
|
|
|
Endpoints:
|
|
GET /health
|
|
POST /transcribe multipart form field "file"
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import cgi
|
|
import gc
|
|
import inspect
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
import traceback
|
|
from http import HTTPStatus
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from pathlib import Path
|
|
|
|
from faster_whisper import WhisperModel
|
|
|
|
try:
|
|
import torch
|
|
except Exception: # pragma: no cover - torch is optional for CPU fallback
|
|
torch = None
|
|
|
|
try:
|
|
import whisperx
|
|
except Exception: # pragma: no cover - keeps existing Faster-Whisper path available
|
|
whisperx = None
|
|
|
|
try:
|
|
import whisperx.diarize as whisperx_diarize
|
|
except Exception: # pragma: no cover - older/newer WhisperX layouts vary
|
|
whisperx_diarize = None
|
|
|
|
try:
|
|
import ctranslate2
|
|
except Exception: # pragma: no cover - only used for runtime device selection
|
|
ctranslate2 = None
|
|
|
|
|
|
# --- Synchronisation primitives -------------------------------------------
# Guards the lazy construction of MODEL in get_model().
MODEL_LOCK = threading.Lock()
# Guards the lazy construction of WHISPERX_MODEL and DIARIZATION_PIPELINE.
WHISPERX_LOCK = threading.Lock()
# Held by the request handler around each transcription, so requests run one at a time.
INFERENCE_LOCK = threading.Lock()

# --- Lazily created singletons (None until first use) ---------------------
MODEL = None
WHISPERX_MODEL = None
DIARIZATION_PIPELINE = None
# WhisperX alignment models cached by (language, device).
ALIGN_MODELS: dict = {}
|
|
|
|
|
|
def choose_device() -> str:
    """Pick the inference device for Faster-Whisper.

    A non-blank ``WHISPER_DEVICE`` (whitespace-stripped) always wins.
    Otherwise ``"cuda"`` is returned when CTranslate2 is importable and
    reports at least one CUDA device. Every other case returns ``"cpu"``,
    including when the CUDA probe raises.
    """
    override = os.getenv("WHISPER_DEVICE", "").strip()
    if override:
        return override

    if ctranslate2 is None:
        return "cpu"

    try:
        has_gpu = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # The probe is best-effort; any failure means "no usable GPU".
        has_gpu = False
    return "cuda" if has_gpu else "cpu"
|
|
|
|
|
|
def default_compute_type(device: str) -> str:
    """Return the Faster-Whisper compute type for ``device``.

    A non-blank ``WHISPER_COMPUTE_TYPE`` (whitespace-stripped) wins.
    Otherwise CUDA devices (any name starting with ``"cuda"``) use
    ``"float16"`` and everything else uses ``"int8"``.
    """
    override = os.getenv("WHISPER_COMPUTE_TYPE", "").strip()
    if override:
        return override
    if device.startswith("cuda"):
        return "float16"
    return "int8"
|
|
|
|
|
|
def diarization_enabled() -> bool:
    """Return whether the WhisperX + diarization path should be attempted.

    Reads ``WHISPERX_DIARIZATION`` (default ``"true"``). The values ``1``,
    ``true``, ``yes`` and ``on`` enable it, case-insensitively. Surrounding
    whitespace is ignored, matching the other env helpers here, so a value
    like ``"true "`` from an env file is not silently treated as disabled.
    """
    value = os.getenv("WHISPERX_DIARIZATION", "true").strip().lower()
    return value in {"1", "true", "yes", "on"}
|
|
|
|
|
|
def whisperx_device() -> str:
    """Return the device used for WhisperX ASR, alignment and diarization.

    A non-blank ``WHISPERX_DEVICE`` (whitespace-stripped) wins. Otherwise the
    Faster-Whisper selection from ``choose_device()`` is reused unchanged.
    The removed ``"cuda" if selected == "cuda" else selected`` always
    returned ``selected`` anyway.
    """
    configured = os.getenv("WHISPERX_DEVICE", "").strip()
    return configured or choose_device()
|
|
|
|
|
|
def whisperx_compute_type(device: str) -> str:
    """Return the compute type for the WhisperX ASR model on ``device``.

    Precedence:
      1. a non-blank ``WHISPERX_COMPUTE_TYPE`` (whitespace-stripped);
      2. ``"float16"`` for any device name starting with ``"cuda"``;
      3. the Faster-Whisper default from ``default_compute_type(device)``.
    """
    explicit = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip()
    if explicit:
        return explicit
    if device.startswith("cuda"):
        return "float16"
    return default_compute_type(device)
|
|
|
|
|
|
def huggingface_token() -> str | None:
    """Return the first non-empty Hugging Face token from the environment.

    Checked in order: ``HUGGINGFACE_TOKEN``, ``HF_TOKEN``,
    ``PYANNOTE_AUTH_TOKEN``. Returns ``None`` when none is set to a
    non-empty value.
    """
    for variable in ("HUGGINGFACE_TOKEN", "HF_TOKEN", "PYANNOTE_AUTH_TOKEN"):
        candidate = os.getenv(variable)
        if candidate:
            return candidate
    return None
|
|
|
|
|
|
def get_model() -> WhisperModel:
    """Return the process-wide Faster-Whisper model, constructing it on first use.

    The unlocked ``None`` check keeps the common path cheap. The second
    check under ``MODEL_LOCK`` ensures that concurrent first callers build
    the model only once.

    Configuration:
      * ``WHISPER_MODEL``: model name (default ``"large-v3"``).
      * device and compute type: ``choose_device()`` / ``default_compute_type()``.
      * ``WHISPER_MODEL_DIR``: passed as ``download_root`` (``None`` when unset or empty).
    """
    global MODEL
    if MODEL is None:
        with MODEL_LOCK:
            if MODEL is None:
                name = os.getenv("WHISPER_MODEL", "large-v3")
                selected_device = choose_device()
                MODEL = WhisperModel(
                    name,
                    device=selected_device,
                    compute_type=default_compute_type(selected_device),
                    download_root=os.getenv("WHISPER_MODEL_DIR") or None,
                )
    return MODEL
|
|
|
|
|
|
def get_whisperx_model():
    """Return the process-wide WhisperX ASR model, loading it on first use.

    Model name: ``WHISPERX_MODEL``, else ``WHISPER_MODEL``, else
    ``"large-v3"``. Device and compute type come from ``whisperx_device()``
    and ``whisperx_compute_type()``. ``WHISPER_MODEL_DIR`` is passed as
    ``download_root``. If that first ``load_model`` call raises
    ``TypeError``, the load is retried without ``download_root``.

    Raises:
        RuntimeError: if WhisperX is not installed and no model is cached.
    """
    global WHISPERX_MODEL
    if WHISPERX_MODEL is None:
        if whisperx is None:
            raise RuntimeError("whisperx is not installed")
        with WHISPERX_LOCK:
            if WHISPERX_MODEL is None:
                target = whisperx_device()
                name = os.getenv("WHISPERX_MODEL", os.getenv("WHISPER_MODEL", "large-v3"))
                compute_type = whisperx_compute_type(target)
                download_root = os.getenv("WHISPER_MODEL_DIR") or None
                try:
                    WHISPERX_MODEL = whisperx.load_model(
                        name,
                        target,
                        compute_type=compute_type,
                        download_root=download_root,
                    )
                except TypeError:
                    # Presumably a WhisperX build whose load_model() has no
                    # download_root parameter; retry without it.
                    WHISPERX_MODEL = whisperx.load_model(name, target, compute_type=compute_type)
    return WHISPERX_MODEL
|
|
|
|
|
|
def get_align_model(language_code: str | None, device: str):
    """Return a cached WhisperX alignment model for a language/device pair.

    An empty ``language_code`` falls back to ``WHISPERX_ALIGN_LANGUAGE``
    (default ``"en"``). Results of ``whisperx.load_align_model`` are stored
    in ``ALIGN_MODELS`` keyed by ``(language, device)`` and reused; this
    function never evicts them. The cache is not locked itself. Within this
    file it is only reached from ``transcribe_with_whisperx``, which the
    request handler runs under ``INFERENCE_LOCK``.

    Raises:
        RuntimeError: if WhisperX is not installed.
    """
    if whisperx is None:
        raise RuntimeError("whisperx is not installed")

    resolved_language = language_code or os.getenv("WHISPERX_ALIGN_LANGUAGE", "en")
    cache_key = (resolved_language, device)
    try:
        return ALIGN_MODELS[cache_key]
    except KeyError:
        pass

    loaded = whisperx.load_align_model(language_code=resolved_language, device=device)
    ALIGN_MODELS[cache_key] = loaded
    return loaded
|
|
|
|
|
|
def _diarization_pipeline_kwargs(parameters, token: str, device: str) -> dict:
    """Build constructor kwargs for a WhisperX ``DiarizationPipeline`` factory.

    ``parameters`` is the mapping from ``inspect.signature(factory).parameters``.
    ``device`` is always passed. The Hugging Face token goes under the first of
    ``use_auth_token`` / ``auth_token`` / ``token`` that the factory accepts,
    since the accepted name differs between WhisperX versions.
    ``WHISPERX_DIARIZATION_MODEL`` is forwarded as ``model_name`` only when it
    is set and the factory accepts that parameter. Otherwise the library's own
    default model is left in place rather than overridden with ``None``.
    """
    kwargs = {"device": device}
    for token_param in ("use_auth_token", "auth_token", "token"):
        if token_param in parameters:
            kwargs[token_param] = token
            break
    model_name = os.getenv("WHISPERX_DIARIZATION_MODEL", "").strip()
    if model_name and "model_name" in parameters:
        kwargs["model_name"] = model_name
    return kwargs


def get_diarization_pipeline(device: str):
    """Return the process-wide WhisperX diarization pipeline, creating it on first use.

    The pipeline class is looked up on ``whisperx`` first and then on
    ``whisperx.diarize``. Constructor arguments are chosen by inspecting its
    signature (see ``_diarization_pipeline_kwargs``). Once created, the
    pipeline is reused regardless of the ``device`` passed on later calls.

    Raises:
        RuntimeError: if WhisperX is not installed, no Hugging Face token is
            configured, or no ``DiarizationPipeline`` class can be found.
    """
    global DIARIZATION_PIPELINE
    if whisperx is None:
        raise RuntimeError("whisperx is not installed")
    if DIARIZATION_PIPELINE is not None:
        return DIARIZATION_PIPELINE

    token = huggingface_token()
    if not token:
        raise RuntimeError("HuggingFace token is required for speaker diarization")

    with WHISPERX_LOCK:
        if DIARIZATION_PIPELINE is None:
            pipeline_factory = getattr(whisperx, "DiarizationPipeline", None)
            if pipeline_factory is None and whisperx_diarize is not None:
                pipeline_factory = getattr(whisperx_diarize, "DiarizationPipeline", None)
            if pipeline_factory is None:
                raise RuntimeError("WhisperX diarization pipeline is not available")

            parameters = inspect.signature(pipeline_factory).parameters
            DIARIZATION_PIPELINE = pipeline_factory(
                **_diarization_pipeline_kwargs(parameters, token, device)
            )
    return DIARIZATION_PIPELINE
|
|
|
|
|
|
def cleanup_gpu_memory() -> None:
    """Run a garbage-collection pass and release PyTorch's cached CUDA memory.

    The CUDA step is skipped when torch (or ``torch.cuda``) is unavailable.
    It is best-effort: any exception raised while querying or emptying the
    CUDA cache is ignored.
    """
    gc.collect()

    cuda = getattr(torch, "cuda", None) if torch is not None else None
    if cuda is None:
        return

    try:
        if cuda.is_available():
            cuda.empty_cache()
    except Exception:
        pass
|
|
|
|
|
|
def normalize_speaker(raw: str | None, mapping: dict[str, str]) -> str:
    """Map a raw diarization label (e.g. ``"SPEAKER_00"``) to a display label.

    Raw labels are numbered ``"Speaker 1"``, ``"Speaker 2"``, ... in order of
    first appearance. ``mapping`` records the assignments made so far and is
    updated in place, so sharing one mapping keeps labels stable across a
    whole transcript.

    A missing or empty label returns ``"Unknown speaker"`` and leaves
    ``mapping`` untouched. Returning ``"Speaker 1"`` here would make
    unattributed segments indistinguishable from the first real speaker,
    and grouping by speaker would merge them into that speaker's turns.
    """
    if not raw:
        return "Unknown speaker"
    key = str(raw)
    if key not in mapping:
        mapping[key] = f"Speaker {len(mapping) + 1}"
    return mapping[key]
|
|
|
|
|
|
def segment_speaker(segment: dict) -> str | None:
    """Return the raw diarization label for one WhisperX segment.

    A truthy segment-level ``"speaker"`` wins. Otherwise the words are
    tallied and the label carried by the most words is returned, with ties
    going to the label seen first. Returns ``None`` when neither the segment
    nor any of its words carries a label.
    """
    direct = segment.get("speaker")
    if direct:
        return direct

    tally: dict[str, int] = {}
    for word in segment.get("words") or []:
        label = word.get("speaker")
        if label:
            tally[label] = tally.get(label, 0) + 1

    # max() over the dict walks keys in insertion order, so ties keep the earliest label.
    return max(tally, key=tally.__getitem__) if tally else None
|
|
|
|
|
|
def normalize_segments(segments: list[dict], include_speakers: bool = True) -> list[dict]:
    """Convert raw transcription segments into the API's timestamp entries.

    Segments whose text is empty after stripping are dropped. Each kept
    entry has ``start`` and ``end`` as floats rounded to 3 decimals (missing
    or falsy values become ``0.0``) and the stripped ``text``. When
    ``include_speakers`` is true, a ``speaker`` display label is added. One
    label mapping is shared across the whole list, so speakers are numbered
    by first appearance.
    """
    shared_labels: dict[str, str] = {}

    def to_entry(source: dict, text: str) -> dict:
        entry = {
            "start": round(float(source.get("start") or 0), 3),
            "end": round(float(source.get("end") or 0), 3),
            "text": text,
        }
        if include_speakers:
            entry["speaker"] = normalize_speaker(segment_speaker(source), shared_labels)
        return entry

    stripped = ((source, str(source.get("text") or "").strip()) for source in segments)
    return [to_entry(source, text) for source, text in stripped if text]
|
|
|
|
|
|
def transcribe_with_whisperx(audio_path: str) -> dict:
    """Transcribe ``audio_path`` with WhisperX, then align and diarize it.

    Alignment and diarization are best-effort. A failure in either prints a
    traceback and the pipeline continues with what it has. Errors from model
    loading, audio loading or transcription propagate to the caller.

    Returns a dict with:
      * ``transcript_text``: segment texts joined by single spaces;
      * ``language``: as reported by the WhisperX transcription result;
      * ``duration``: sample count / 16000 (assumes WhisperX's 16 kHz
        audio), or ``None`` if the audio object has no ``shape``;
      * ``timestamps``: output of ``normalize_segments``, with speaker labels
        only when diarization succeeded;
      * ``diarization``: ``"completed"`` or ``"fallback"``.
    """
    device = whisperx_device()
    model = get_whisperx_model()
    audio = whisperx.load_audio(audio_path)
    batch_size = int(os.getenv("WHISPERX_BATCH_SIZE", os.getenv("WHISPER_BATCH_SIZE", "8")))
    result = model.transcribe(audio, batch_size=batch_size)
    # Captured before alignment, since the aligned result is used without it.
    language = result.get("language")

    # Word-level alignment: on failure keep the unaligned transcription result.
    try:
        align_model, align_metadata = get_align_model(language, device)
        result = whisperx.align(
            result.get("segments", []),
            align_model,
            align_metadata,
            audio,
            device,
            return_char_alignments=False,
        )
    except Exception:
        traceback.print_exc()

    # Speaker diarization: on failure the output carries no speaker labels.
    speakers_assigned = True
    try:
        pipeline = get_diarization_pipeline(device)
        speaker_hints = {
            option: int(raw)
            for option, raw in (
                ("min_speakers", os.getenv("WHISPERX_MIN_SPEAKERS")),
                ("max_speakers", os.getenv("WHISPERX_MAX_SPEAKERS")),
            )
            if raw
        }
        diarize_segments = pipeline(audio, **speaker_hints)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    except Exception:
        speakers_assigned = False
        traceback.print_exc()

    timestamps = normalize_segments(result.get("segments", []), include_speakers=speakers_assigned)
    transcript_text = " ".join(entry["text"] for entry in timestamps if entry.get("text"))
    duration = float(audio.shape[0]) / 16000 if hasattr(audio, "shape") else None

    return {
        "transcript_text": transcript_text,
        "language": language,
        "duration": duration,
        "timestamps": timestamps,
        "diarization": "completed" if speakers_assigned else "fallback",
    }
|
|
|
|
|
|
def transcribe_with_faster_whisper(audio_path: str) -> dict:
    """Transcribe ``audio_path`` with Faster-Whisper, without speaker labels.

    Segments are passed through ``normalize_segments(include_speakers=False)``,
    so this path returns the same timestamp shape as the WhisperX path.
    Blank-text segments are dropped and ``start``/``end`` are rounded to 3
    decimals. The earlier code emitted blank entries and unrounded floats.

    Config: ``WHISPER_BEAM_SIZE`` (default 5) and ``WHISPER_VAD_FILTER``
    (default on; ``1``/``true``/``yes``/``on`` enable it).

    Returns a dict with ``transcript_text``, ``language``, ``duration``,
    ``timestamps`` and ``diarization`` (always ``"disabled"``).
    """
    model = get_model()
    segments, info = model.transcribe(
        audio_path,
        beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")),
        vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"},
    )
    # Faster-Whisper yields segments lazily; decoding happens during this iteration.
    raw_segments = [
        {"start": segment.start, "end": segment.end, "text": segment.text}
        for segment in segments
    ]
    timestamps = normalize_segments(raw_segments, include_speakers=False)

    return {
        "transcript_text": " ".join(entry["text"] for entry in timestamps),
        "language": getattr(info, "language", None),
        "duration": getattr(info, "duration", None),
        "timestamps": timestamps,
        "diarization": "disabled",
    }
|
|
|
|
|
|
class WhisperHandler(BaseHTTPRequestHandler):
    """JSON HTTP handler for the Orphion Whisper service.

    Routes:
      * ``GET /health``: configuration and readiness summary.
      * ``POST /transcribe``: multipart upload (field name from
        ``WHISPER_FILE_FIELD``, default ``"file"``). It is transcribed with
        WhisperX when diarization is enabled and with Faster-Whisper
        otherwise. Transcriptions are serialised through ``INFERENCE_LOCK``.
    """

    server_version = "OrphionWhisper/1.0"

    def do_GET(self) -> None:
        """Serve ``/health``; any other GET path gets a JSON 404."""
        if self.path.rstrip("/") != "/health":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        self.send_json(
            {
                "status": "ok",
                "model": os.getenv("WHISPER_MODEL", "large-v3"),
                "device": choose_device(),
                "whisperx": whisperx is not None,
                "diarization": diarization_enabled(),
                # Configuration check only: WhisperX importable, diarization on,
                # and a Hugging Face token present. The pipeline is not loaded here.
                "diarization_ready": whisperx is not None
                and diarization_enabled()
                and bool(huggingface_token()),
            },
        )

    def do_POST(self) -> None:
        """Serve ``/transcribe``; unexpected errors are logged and returned as a JSON 500."""
        if self.path.rstrip("/") != "/transcribe":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        try:
            self.handle_transcribe()
        except Exception as exc:
            traceback.print_exc()
            self.send_json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def handle_transcribe(self) -> None:
        """Spool the uploaded audio to a temp file, transcribe it and reply with JSON.

        Replies 400 for a non-multipart request or a missing file field. The
        temp file is removed afterwards even if spooling the upload fails.
        """
        content_type = self.headers.get("content-type", "")
        if not content_type.startswith("multipart/form-data"):
            self.send_json({"error": "multipart/form-data required"}, HTTPStatus.BAD_REQUEST)
            return

        # NOTE(review): the cgi module is deprecated (PEP 594) and removed in
        # Python 3.13; this parser will need replacing before upgrading.
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={
                "REQUEST_METHOD": "POST",
                "CONTENT_TYPE": content_type,
                "CONTENT_LENGTH": self.headers.get("content-length", "0"),
            },
        )
        field_name = os.getenv("WHISPER_FILE_FIELD", "file")
        upload = form[field_name] if field_name in form else None
        if upload is None or not getattr(upload, "file", None):
            self.send_json({"error": f"missing multipart field '{field_name}'"}, HTTPStatus.BAD_REQUEST)
            return

        suffix = Path(getattr(upload, "filename", "") or "audio").suffix
        temp_path = None
        try:
            # delete=False because the backends reopen the file by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                while True:
                    chunk = upload.file.read(1024 * 1024)
                    if not chunk:
                        break
                    temp_file.write(chunk)

            try:
                payload = self._transcribe_file(temp_path)
                self.send_json(payload)
            finally:
                cleanup_gpu_memory()
        finally:
            # Also runs when spooling the upload failed, so partial files never leak.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def _transcribe_file(self, audio_path: str) -> dict:
        """Run the configured backend on ``audio_path`` while holding ``INFERENCE_LOCK``.

        With diarization enabled, WhisperX is tried first. If it raises, the
        traceback is printed, memory is released, and Faster-Whisper runs
        instead, with the payload marked ``"fallback"``. With diarization
        disabled, Faster-Whisper runs exactly once and its errors propagate.
        They are not retried identically and mislabelled as a fallback.
        """
        with INFERENCE_LOCK:
            if not diarization_enabled():
                return transcribe_with_faster_whisper(audio_path)
            try:
                return transcribe_with_whisperx(audio_path)
            except Exception:
                traceback.print_exc()
            # Free whatever the failed WhisperX attempt left behind before loading the fallback.
            cleanup_gpu_memory()
            payload = transcribe_with_faster_whisper(audio_path)
            payload["diarization"] = "fallback"
            return payload

    def log_message(self, fmt: str, *args: object) -> None:
        """Log requests to stdout (flushed) instead of the default stderr."""
        print(f"{self.address_string()} - {fmt % args}", flush=True)

    def send_json(self, payload: dict, status: HTTPStatus = HTTPStatus.OK) -> None:
        """Send ``payload`` as a UTF-8 JSON body with an explicit content-length."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
|
|
|
|
|
|
def main() -> None:
    """Parse CLI options and serve the Whisper API until the process is stopped.

    ``--host`` defaults to ``WHISPER_HOST`` (``"0.0.0.0"``) and ``--port`` to
    ``WHISPER_PORT`` (``8000``). Requests are handled on threads by
    ``ThreadingHTTPServer``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default=os.getenv("WHISPER_HOST", "0.0.0.0"))
    cli.add_argument("--port", default=int(os.getenv("WHISPER_PORT", "8000")), type=int)
    options = cli.parse_args()

    bind_address = (options.host, options.port)
    httpd = ThreadingHTTPServer(bind_address, WhisperHandler)
    print(f"Whisper API listening on http://{options.host}:{options.port}", flush=True)
    httpd.serve_forever()


if __name__ == "__main__":
    main()
|