Mom-Portal/scripts/whisper_http_server.py
KevinB-T 30894e7f27 Configure VM services and fix auth flow
- point MySQL and Whisper settings to the VM
- add VM MySQL bootstrap scripts and docs
- allow LAN Vite origins for CORS
- fix Express 5 validation assignment crash
- allow login with username or email
- prevent recursive auth refresh retries
2026-05-15 15:19:49 +05:30

184 lines
5.7 KiB
Python

#!/usr/bin/env python3
"""Tiny Faster-Whisper HTTP API for Orphion.
Endpoints:
GET /health
POST /transcribe multipart form field "file"
"""
from __future__ import annotations
import argparse
import json
import os
import tempfile
import threading
import traceback
from email.message import Message
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path

from faster_whisper import WhisperModel

try:
    import cgi  # legacy dependency; removed from the stdlib in Python 3.13 (PEP 594)
except ImportError:
    cgi = None
try:
import ctranslate2
except Exception: # pragma: no cover - only used for runtime device selection
ctranslate2 = None
# Serialises the one-time construction of MODEL inside get_model().
MODEL_LOCK = threading.Lock()
# Shared model instance; stays None until get_model() builds it on first use.
MODEL = None
def choose_device() -> str:
    """Pick the device name to run the model on.

    A non-blank ``WHISPER_DEVICE`` environment variable always wins (returned
    stripped). Otherwise ``"cuda"`` is returned when ctranslate2 was imported
    and reports at least one CUDA device; in every other case, including a
    failing probe, the answer is ``"cpu"``.
    """
    override = os.environ.get("WHISPER_DEVICE", "").strip()
    if override:
        return override

    if ctranslate2 is None:
        return "cpu"

    try:
        has_gpu = bool(ctranslate2.get_cuda_device_count() > 0)
    except Exception:
        # Best effort: a broken CUDA probe must not stop the CPU fallback.
        has_gpu = False
    return "cuda" if has_gpu else "cpu"
def default_compute_type(device: str) -> str:
    """Return the compute type to load the model with on ``device``.

    A non-blank ``WHISPER_COMPUTE_TYPE`` environment variable is returned
    stripped. Without it the result is ``"float16"`` for ``"cuda"`` and
    ``"int8"`` for any other device string.
    """
    override = os.environ.get("WHISPER_COMPUTE_TYPE", "").strip()
    if not override:
        override = "float16" if device == "cuda" else "int8"
    return override
def get_model() -> WhisperModel:
    """Return the shared WhisperModel, constructing it on the first call.

    The model name comes from ``WHISPER_MODEL`` (default ``"large-v3"``), the
    device from choose_device(), the compute type from default_compute_type()
    and the download directory from ``WHISPER_MODEL_DIR`` (``None`` when unset
    or empty). MODEL is re-checked after MODEL_LOCK is acquired so that callers
    racing on the first request build the model only once.
    """
    global MODEL
    if MODEL is None:
        with MODEL_LOCK:
            # Another thread may have finished loading while this one waited.
            if MODEL is None:
                device = choose_device()
                MODEL = WhisperModel(
                    os.getenv("WHISPER_MODEL", "large-v3"),
                    device=device,
                    compute_type=default_compute_type(device),
                    download_root=os.getenv("WHISPER_MODEL_DIR") or None,
                )
    return MODEL
class WhisperHandler(BaseHTTPRequestHandler):
    """Request handler for the Whisper HTTP API.

    Routes:
        GET  /health      JSON with a status flag, the configured model name
                          and the device chosen by choose_device().
        POST /transcribe  JSON transcript of the audio uploaded as
                          multipart/form-data (field name taken from
                          ``WHISPER_FILE_FIELD``, default ``"file"``).

    Every response, including errors, is a JSON document written by send_json().
    """

    server_version = "OrphionWhisper/1.0"

    def do_GET(self) -> None:
        """Serve the health probe; any other GET path gets a 404."""
        if self.path.rstrip("/") != "/health":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        self.send_json(
            {
                "status": "ok",
                "model": os.getenv("WHISPER_MODEL", "large-v3"),
                "device": choose_device(),
            },
        )

    def do_POST(self) -> None:
        """Dispatch POST /transcribe; any other POST path gets a 404.

        Any exception raised while handling the upload is printed with its
        traceback and reported to the client as a 500 whose ``error`` field is
        the exception text.
        """
        if self.path.rstrip("/") != "/transcribe":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        try:
            self.handle_transcribe()
        except Exception as exc:
            traceback.print_exc()
            self.send_json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def handle_transcribe(self) -> None:
        """Transcribe the uploaded audio and reply with the result as JSON.

        The request must be multipart/form-data with a boundary parameter and
        a Content-Length header, and must carry a non-empty part named by
        ``WHISPER_FILE_FIELD`` (default ``"file"``); otherwise a 4xx JSON error
        is sent. The audio is written to a temporary file (keeping the
        uploaded filename's suffix), transcribed with the shared model using
        ``WHISPER_BEAM_SIZE`` (default 5) and ``WHISPER_VAD_FILTER`` (default
        on), and the temporary file is removed afterwards in all cases.

        The multipart body is parsed here instead of with ``cgi.FieldStorage``
        because the ``cgi`` module no longer exists on Python 3.13+.
        """
        # get_content_type() lower-cases the value; MIME types are case-insensitive.
        if self.headers.get_content_type() != "multipart/form-data":
            self.send_json({"error": "multipart/form-data required"}, HTTPStatus.BAD_REQUEST)
            return
        boundary = self.headers.get_boundary()
        if not boundary:
            self.send_json({"error": "multipart boundary missing"}, HTTPStatus.BAD_REQUEST)
            return
        try:
            length = int(self.headers.get("content-length", ""))
        except ValueError:
            length = -1
        if length < 0:
            # Without a length there is no safe way to know where the body ends.
            self.send_json({"error": "content-length required"}, HTTPStatus.LENGTH_REQUIRED)
            return

        # NOTE(review): the whole body is buffered in memory and no upload size
        # limit is enforced - consider capping `length` for untrusted clients.
        body = self.rfile.read(length)

        field_name = os.getenv("WHISPER_FILE_FIELD", "file")
        upload = self._parse_multipart(body, boundary).get(field_name)
        if upload is None:
            self.send_json({"error": f"missing multipart field '{field_name}'"}, HTTPStatus.BAD_REQUEST)
            return
        filename, audio = upload
        if not audio:
            self.send_json({"error": f"multipart field '{field_name}' is empty"}, HTTPStatus.BAD_REQUEST)
            return

        suffix = Path(filename or "audio").suffix
        temp_path = None
        try:
            # The write happens inside the try so that a failed copy (e.g. a
            # full disk) cannot leave the delete=False file behind.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                temp_file.write(audio)

            model = get_model()
            segments, info = model.transcribe(
                temp_path,
                beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")),
                vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"},
            )
            timestamps = []
            transcript_parts = []
            for segment in segments:
                text = segment.text.strip()
                transcript_parts.append(text)
                timestamps.append({"start": segment.start, "end": segment.end, "text": text})
            self.send_json(
                {
                    "transcript_text": " ".join(part for part in transcript_parts if part),
                    "language": getattr(info, "language", None),
                    "duration": getattr(info, "duration", None),
                    "timestamps": timestamps,
                },
            )
        finally:
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass  # best-effort cleanup; the response has already been decided

    @staticmethod
    def _parse_multipart(body: bytes, boundary: str) -> dict:
        """Split a multipart/form-data body into its named parts.

        Args:
            body: Raw request body using CRLF line endings.
            boundary: Boundary parameter from the request's Content-Type.

        Returns:
            ``{field name: (filename, content)}`` where ``filename`` is ``None``
            for parts that carry no filename and ``content`` is the part's raw
            bytes. When a field name repeats, the first part wins. Parts without
            a usable Content-Disposition ``name`` are skipped.
        """
        delimiter = b"\r\n--" + boundary.encode("latin-1")
        fields = {}
        # The first delimiter has no line break in front of it, so add one to
        # let a single split() handle every delimiter uniformly.
        sections = (b"\r\n" + body).split(delimiter)
        for section in sections[1:]:  # sections[0] is the preamble
            if section.startswith(b"--"):
                break  # closing delimiter ("--boundary--"); the rest is epilogue
            raw_headers, blank_line, content = section.partition(b"\r\n\r\n")
            if not blank_line:
                continue  # malformed part: headers never terminated
            disposition = Message()
            for line in raw_headers.decode("utf-8", "replace").splitlines():
                header_name, colon, header_value = line.partition(":")
                if colon and header_name.strip().lower() == "content-disposition":
                    disposition["content-disposition"] = header_value.strip()
                    break
            field = disposition.get_param("name", header="content-disposition")
            if isinstance(field, str) and field not in fields:
                fields[field] = (disposition.get_filename(), content)
        return fields

    def log_message(self, fmt: str, *args: object) -> None:
        """Print each log line to stdout, flushed immediately."""
        print(f"{self.address_string()} - {fmt % args}", flush=True)

    def send_json(self, payload: dict, status: HTTPStatus = HTTPStatus.OK) -> None:
        """Send ``payload`` as a complete JSON response with the given status."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def main() -> None:
    """Parse the command line and serve the Whisper API until interrupted.

    ``--host`` defaults to ``WHISPER_HOST`` (or ``0.0.0.0``) and ``--port`` to
    ``WHISPER_PORT`` (or ``8000``). Requests are handled by WhisperHandler on
    a ThreadingHTTPServer.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default=os.getenv("WHISPER_HOST", "0.0.0.0"))
    cli.add_argument("--port", type=int, default=int(os.getenv("WHISPER_PORT", "8000")))
    options = cli.parse_args()

    bind_address = (options.host, options.port)
    httpd = ThreadingHTTPServer(bind_address, WhisperHandler)
    print(f"Whisper API listening on http://{options.host}:{options.port}", flush=True)
    httpd.serve_forever()
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()