#!/usr/bin/env python3 """Tiny Faster-Whisper/WhisperX HTTP API for Orphion. Endpoints: GET /health POST /transcribe multipart form field "file" """ from __future__ import annotations import argparse import cgi import gc import inspect import json import os import tempfile import threading import traceback from http import HTTPStatus from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from faster_whisper import WhisperModel try: import torch except Exception: # pragma: no cover - torch is optional for CPU fallback torch = None try: import whisperx except Exception: # pragma: no cover - keeps existing Faster-Whisper path available whisperx = None try: import whisperx.diarize as whisperx_diarize except Exception: # pragma: no cover - older/newer WhisperX layouts vary whisperx_diarize = None try: import ctranslate2 except Exception: # pragma: no cover - only used for runtime device selection ctranslate2 = None MODEL_LOCK = threading.Lock() MODEL = None WHISPERX_LOCK = threading.Lock() INFERENCE_LOCK = threading.Lock() WHISPERX_MODEL = None ALIGN_MODELS = {} DIARIZATION_PIPELINE = None def choose_device() -> str: configured = os.getenv("WHISPER_DEVICE", "").strip() if configured: return configured if ctranslate2 is not None: try: if ctranslate2.get_cuda_device_count() > 0: return "cuda" except Exception: pass return "cpu" def default_compute_type(device: str) -> str: configured = os.getenv("WHISPER_COMPUTE_TYPE", "").strip() if configured: return configured return "float16" if device.startswith("cuda") else "int8" def diarization_enabled() -> bool: return os.getenv("WHISPERX_DIARIZATION", "true").lower() in {"1", "true", "yes", "on"} def whisperx_device() -> str: configured = os.getenv("WHISPERX_DEVICE", "").strip() if configured: return configured selected = choose_device() return "cuda" if selected == "cuda" else selected def whisperx_compute_type(device: str) -> str: configured = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip() if configured: return configured return "float16" if device.startswith("cuda") else default_compute_type(device) def huggingface_token() -> str | None: return ( os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("PYANNOTE_AUTH_TOKEN") or None ) def get_model() -> WhisperModel: global MODEL if MODEL is not None: return MODEL with MODEL_LOCK: if MODEL is None: model_name = os.getenv("WHISPER_MODEL", "large-v3") device = choose_device() MODEL = WhisperModel( model_name, device=device, compute_type=default_compute_type(device), download_root=os.getenv("WHISPER_MODEL_DIR") or None, ) return MODEL def get_whisperx_model(): global WHISPERX_MODEL if WHISPERX_MODEL is not None: return WHISPERX_MODEL if whisperx is None: raise RuntimeError("whisperx is not installed") with WHISPERX_LOCK: if WHISPERX_MODEL is None: device = whisperx_device() model_name = os.getenv("WHISPERX_MODEL", os.getenv("WHISPER_MODEL", "large-v3")) kwargs = { "compute_type": whisperx_compute_type(device), "download_root": os.getenv("WHISPER_MODEL_DIR") or None, } try: WHISPERX_MODEL = whisperx.load_model(model_name, device, **kwargs) except TypeError: kwargs.pop("download_root", None) WHISPERX_MODEL = whisperx.load_model(model_name, device, **kwargs) return WHISPERX_MODEL def get_align_model(language_code: str | None, device: str): if whisperx is None: raise RuntimeError("whisperx is not installed") language = language_code or os.getenv("WHISPERX_ALIGN_LANGUAGE", "en") key = (language, device) if key not in ALIGN_MODELS: ALIGN_MODELS[key] = whisperx.load_align_model(language_code=language, device=device) return ALIGN_MODELS[key] def get_diarization_pipeline(device: str): global DIARIZATION_PIPELINE if whisperx is None: raise RuntimeError("whisperx is not installed") if DIARIZATION_PIPELINE is not None: return DIARIZATION_PIPELINE token = huggingface_token() if not token: raise RuntimeError("HuggingFace token is required for speaker diarization") with WHISPERX_LOCK: if DIARIZATION_PIPELINE is None: pipeline_factory = getattr(whisperx, "DiarizationPipeline", None) if pipeline_factory is None and whisperx_diarize is not None: pipeline_factory = getattr(whisperx_diarize, "DiarizationPipeline", None) if pipeline_factory is None: raise RuntimeError("WhisperX diarization pipeline is not available") parameters = inspect.signature(pipeline_factory).parameters if "use_auth_token" in parameters: DIARIZATION_PIPELINE = pipeline_factory(use_auth_token=token, device=device) elif "auth_token" in parameters: DIARIZATION_PIPELINE = pipeline_factory(auth_token=token, device=device) elif "token" in parameters: DIARIZATION_PIPELINE = pipeline_factory( model_name=os.getenv("WHISPERX_DIARIZATION_MODEL") or None, token=token, device=device, ) else: DIARIZATION_PIPELINE = pipeline_factory(device=device) return DIARIZATION_PIPELINE def cleanup_gpu_memory() -> None: gc.collect() if torch is not None and getattr(torch, "cuda", None) is not None: try: if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass def normalize_speaker(raw: str | None, mapping: dict[str, str]) -> str: if not raw: return "Speaker 1" key = str(raw) if key not in mapping: mapping[key] = f"Speaker {len(mapping) + 1}" return mapping[key] def segment_speaker(segment: dict) -> str | None: if segment.get("speaker"): return segment.get("speaker") counts: dict[str, int] = {} for word in segment.get("words") or []: speaker = word.get("speaker") if speaker: counts[speaker] = counts.get(speaker, 0) + 1 if not counts: return None return max(counts.items(), key=lambda item: item[1])[0] def normalize_segments(segments: list[dict], include_speakers: bool = True) -> list[dict]: speaker_map: dict[str, str] = {} normalized = [] for segment in segments: text = str(segment.get("text") or "").strip() if not text: continue item = { "start": round(float(segment.get("start") or 0), 3), "end": round(float(segment.get("end") or 0), 3), "text": text, } if include_speakers: item["speaker"] = normalize_speaker(segment_speaker(segment), speaker_map) normalized.append(item) return normalized def transcribe_with_whisperx(audio_path: str) -> dict: device = whisperx_device() model = get_whisperx_model() audio = whisperx.load_audio(audio_path) batch_size = int(os.getenv("WHISPERX_BATCH_SIZE", os.getenv("WHISPER_BATCH_SIZE", "8"))) result = model.transcribe(audio, batch_size=batch_size) language = result.get("language") try: align_model, metadata = get_align_model(language, device) result = whisperx.align( result.get("segments", []), align_model, metadata, audio, device, return_char_alignments=False, ) except Exception: traceback.print_exc() diarization_failed = False try: diarize_model = get_diarization_pipeline(device) min_speakers = os.getenv("WHISPERX_MIN_SPEAKERS") max_speakers = os.getenv("WHISPERX_MAX_SPEAKERS") kwargs = {} if min_speakers: kwargs["min_speakers"] = int(min_speakers) if max_speakers: kwargs["max_speakers"] = int(max_speakers) diarize_segments = diarize_model(audio, **kwargs) result = whisperx.assign_word_speakers(diarize_segments, result) except Exception: diarization_failed = True traceback.print_exc() timestamps = normalize_segments(result.get("segments", []), include_speakers=not diarization_failed) transcript_text = " ".join(segment["text"] for segment in timestamps if segment.get("text")) return { "transcript_text": transcript_text, "language": language, "duration": float(audio.shape[0]) / 16000 if hasattr(audio, "shape") else None, "timestamps": timestamps, "diarization": "fallback" if diarization_failed else "completed", } def transcribe_with_faster_whisper(audio_path: str) -> dict: model = get_model() segments, info = model.transcribe( audio_path, beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")), vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"}, ) timestamps = [] transcript_parts = [] for segment in segments: text = segment.text.strip() transcript_parts.append(text) timestamps.append({"start": segment.start, "end": segment.end, "text": text}) return { "transcript_text": " ".join(part for part in transcript_parts if part), "language": getattr(info, "language", None), "duration": getattr(info, "duration", None), "timestamps": timestamps, "diarization": "disabled", } class WhisperHandler(BaseHTTPRequestHandler): server_version = "OrphionWhisper/1.0" def do_GET(self) -> None: if self.path.rstrip("/") != "/health": self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND) return self.send_json( { "status": "ok", "model": os.getenv("WHISPER_MODEL", "large-v3"), "device": choose_device(), "whisperx": whisperx is not None, "diarization": diarization_enabled(), "diarization_ready": whisperx is not None and diarization_enabled() and bool(huggingface_token()), }, ) def do_POST(self) -> None: if self.path.rstrip("/") != "/transcribe": self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND) return try: self.handle_transcribe() except Exception as exc: traceback.print_exc() self.send_json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR) def handle_transcribe(self) -> None: content_type = self.headers.get("content-type", "") if not content_type.startswith("multipart/form-data"): self.send_json({"error": "multipart/form-data required"}, HTTPStatus.BAD_REQUEST) return form = cgi.FieldStorage( fp=self.rfile, headers=self.headers, environ={ "REQUEST_METHOD": "POST", "CONTENT_TYPE": content_type, "CONTENT_LENGTH": self.headers.get("content-length", "0"), }, ) field_name = os.getenv("WHISPER_FILE_FIELD", "file") upload = form[field_name] if field_name in form else None if upload is None or not getattr(upload, "file", None): self.send_json({"error": f"missing multipart field '{field_name}'"}, HTTPStatus.BAD_REQUEST) return suffix = Path(getattr(upload, "filename", "") or "audio").suffix with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: temp_path = temp_file.name while True: chunk = upload.file.read(1024 * 1024) if not chunk: break temp_file.write(chunk) try: with INFERENCE_LOCK: try: if diarization_enabled(): payload = transcribe_with_whisperx(temp_path) else: payload = transcribe_with_faster_whisper(temp_path) except Exception: traceback.print_exc() payload = transcribe_with_faster_whisper(temp_path) payload["diarization"] = "fallback" self.send_json(payload) finally: cleanup_gpu_memory() try: os.unlink(temp_path) except OSError: pass def log_message(self, fmt: str, *args: object) -> None: print(f"{self.address_string()} - {fmt % args}", flush=True) def send_json(self, payload: dict, status: HTTPStatus = HTTPStatus.OK) -> None: body = json.dumps(payload).encode("utf-8") self.send_response(status) self.send_header("content-type", "application/json") self.send_header("content-length", str(len(body))) self.end_headers() self.wfile.write(body) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--host", default=os.getenv("WHISPER_HOST", "0.0.0.0")) parser.add_argument("--port", default=int(os.getenv("WHISPER_PORT", "8000")), type=int) args = parser.parse_args() server = ThreadingHTTPServer((args.host, args.port), WhisperHandler) print(f"Whisper API listening on http://{args.host}:{args.port}", flush=True) server.serve_forever() if __name__ == "__main__": main()