- point MySQL and Whisper settings to the VM - add VM MySQL bootstrap scripts and docs - allow LAN Vite origins for CORS - fix Express 5 validation assignment crash - allow login with username or email - prevent recursive auth refresh retries
184 lines
5.7 KiB
Python
184 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
"""Tiny Faster-Whisper HTTP API for Orphion.
|
|
|
|
Endpoints:
|
|
GET /health
|
|
POST /transcribe multipart form field "file"
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import cgi
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
import traceback
|
|
from http import HTTPStatus
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from pathlib import Path
|
|
|
|
from faster_whisper import WhisperModel
|
|
|
|
try:
|
|
import ctranslate2
|
|
except Exception: # pragma: no cover - only used for runtime device selection
|
|
ctranslate2 = None
|
|
|
|
|
|
# Guards the one-time construction of the shared model in get_model().
MODEL_LOCK = threading.Lock()
# Shared WhisperModel instance; stays None until get_model() first loads it.
MODEL: WhisperModel | None = None
|
|
|
|
|
|
def choose_device() -> str:
    """Return the device name the Whisper model should be loaded on.

    A non-blank ``WHISPER_DEVICE`` environment variable wins (surrounding
    whitespace is stripped). Otherwise the result is "cuda" when the optional
    ``ctranslate2`` import succeeded and it reports at least one CUDA device,
    and "cpu" in every other case, including when the CUDA probe raises.
    """
    override = os.getenv("WHISPER_DEVICE", "").strip()
    if override:
        return override

    if ctranslate2 is None:
        return "cpu"

    try:
        has_gpu = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # Probing CUDA is best-effort: any failure is treated as "no GPU".
        has_gpu = False
    return "cuda" if has_gpu else "cpu"
|
|
|
|
|
|
def default_compute_type(device: str) -> str:
    """Return the compute type to load the model with on *device*.

    A non-blank ``WHISPER_COMPUTE_TYPE`` environment variable (whitespace
    stripped) takes precedence. Otherwise the default is "float16" for
    "cuda" and "int8" for any other device.
    """
    override = os.getenv("WHISPER_COMPUTE_TYPE", "").strip()
    if not override:
        override = {"cuda": "float16"}.get(device, "int8")
    return override
|
|
|
|
|
|
def get_model() -> WhisperModel:
    """Return the process-wide WhisperModel, constructing it on first use.

    Double-checked locking: an already-loaded model is returned without
    taking ``MODEL_LOCK``; otherwise the lock ensures only one thread builds
    it. The model name comes from ``WHISPER_MODEL`` (default "large-v3"),
    the device from choose_device(), the compute type from
    default_compute_type(), and ``download_root`` from ``WHISPER_MODEL_DIR``
    (unset or empty passes None).
    """
    global MODEL
    loaded = MODEL
    if loaded is not None:
        return loaded

    with MODEL_LOCK:
        # Another thread may have finished loading while we waited for the lock.
        if MODEL is None:
            name = os.getenv("WHISPER_MODEL", "large-v3")
            target_device = choose_device()
            MODEL = WhisperModel(
                name,
                device=target_device,
                compute_type=default_compute_type(target_device),
                download_root=os.getenv("WHISPER_MODEL_DIR") or None,
            )
        return MODEL
|
|
|
|
|
|
class WhisperHandler(BaseHTTPRequestHandler):
    """HTTP handler serving ``GET /health`` and ``POST /transcribe``.

    Every response this class sends is a JSON object written by send_json().
    """

    server_version = "OrphionWhisper/1.0"

    @staticmethod
    def _route_path(raw_path: str) -> str:
        """Return *raw_path* without query string, fragment, or trailing slashes.

        ``"/health/?verbose=1"`` becomes ``"/health"``, so appended query
        parameters do not turn a valid route into a 404.
        """
        return raw_path.split("?", 1)[0].split("#", 1)[0].rstrip("/")

    def do_GET(self) -> None:
        """Serve the health check; any other GET path gets a JSON 404."""
        if self._route_path(self.path) != "/health":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        self.send_json(
            {
                "status": "ok",
                "model": os.getenv("WHISPER_MODEL", "large-v3"),
                # Re-evaluates the device selection; this does not load the model.
                "device": choose_device(),
            },
        )

    def do_POST(self) -> None:
        """Dispatch ``POST /transcribe``; any other POST path gets a JSON 404.

        An unexpected exception from the transcription path is logged with
        its traceback and returned to the client as a JSON 500.
        """
        if self._route_path(self.path) != "/transcribe":
            self.send_json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        try:
            self.handle_transcribe()
        except Exception as exc:
            traceback.print_exc()
            self.send_json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def handle_transcribe(self) -> None:
        """Transcribe the uploaded audio file and reply with JSON.

        Expects a ``multipart/form-data`` body whose file part is named by
        ``WHISPER_FILE_FIELD`` (default ``"file"``). Replies 400 for a
        non-multipart request, a malformed multipart body, or a missing file
        part. Otherwise replies 200 with the payload built by
        _transcribe_file(). The temporary copy of the upload is always
        removed, even if copying or transcription fails.
        """
        content_type = self.headers.get("content-type", "")
        # Media type names are case-insensitive (RFC 2045), so compare lowercased.
        if not content_type.lower().startswith("multipart/form-data"):
            self.send_json({"error": "multipart/form-data required"}, HTTPStatus.BAD_REQUEST)
            return

        # NOTE: the cgi module is deprecated (PEP 594) and removed in Python
        # 3.13; this parser must be replaced before running on 3.13+.
        try:
            form = cgi.FieldStorage(
                fp=self.rfile,
                headers=self.headers,
                environ={
                    "REQUEST_METHOD": "POST",
                    "CONTENT_TYPE": content_type,
                    "CONTENT_LENGTH": self.headers.get("content-length", "0"),
                },
            )
        except ValueError as exc:
            # cgi raises ValueError for malformed input such as an invalid
            # boundary; that is a client error, not a server fault.
            self.send_json({"error": f"invalid multipart body: {exc}"}, HTTPStatus.BAD_REQUEST)
            return

        field_name = os.getenv("WHISPER_FILE_FIELD", "file")
        upload = form[field_name] if field_name in form else None
        if isinstance(upload, list):
            # The field was sent more than once; FieldStorage returns every
            # part as a list. Transcribe the first one.
            upload = upload[0] if upload else None
        if upload is None or not getattr(upload, "file", None):
            self.send_json({"error": f"missing multipart field '{field_name}'"}, HTTPStatus.BAD_REQUEST)
            return

        # Preserve the uploaded file's extension (if any) on the temp file.
        suffix = Path(getattr(upload, "filename", "") or "audio").suffix
        temp_path = None
        try:
            # delete=False: the file must outlive this ``with`` so the model
            # can open it by path; the ``finally`` below removes it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                while True:
                    chunk = upload.file.read(1024 * 1024)
                    if not chunk:
                        break
                    temp_file.write(chunk)
            self.send_json(self._transcribe_file(temp_path))
        finally:
            # Also runs when copying the upload fails, so no partial temp file
            # is left behind.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass  # best-effort cleanup; must not mask the request outcome

    @staticmethod
    def _transcribe_file(audio_path: str) -> dict:
        """Transcribe *audio_path* with the shared model and build the reply payload.

        Beam size comes from ``WHISPER_BEAM_SIZE`` (default 5). VAD filtering
        comes from ``WHISPER_VAD_FILTER`` (default on); "1", "true", "yes" and
        "on", matched case-insensitively, enable it.

        Returns:
            A dict with ``transcript_text`` (the non-empty stripped segment
            texts joined by single spaces) and ``language`` / ``duration``
            (None when the info object lacks them). It also has
            ``timestamps``, one ``{"start", "end", "text"}`` entry per
            segment, including segments whose text is empty.
        """
        model = get_model()
        segments, info = model.transcribe(
            audio_path,
            beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")),
            vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"},
        )
        timestamps = [
            {"start": segment.start, "end": segment.end, "text": segment.text.strip()}
            for segment in segments
        ]
        return {
            "transcript_text": " ".join(entry["text"] for entry in timestamps if entry["text"]),
            "language": getattr(info, "language", None),
            "duration": getattr(info, "duration", None),
            "timestamps": timestamps,
        }

    def log_message(self, fmt: str, *args: object) -> None:
        """Print log lines to stdout, flushed immediately, instead of stderr."""
        print(f"{self.address_string()} - {fmt % args}", flush=True)

    def send_json(self, payload: dict, status: HTTPStatus = HTTPStatus.OK) -> None:
        """Send *payload* as a complete UTF-8 JSON response with *status*."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Parse command-line options and serve the Whisper API until interrupted.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. None (the
            default) parses the process's own arguments.

    ``--host`` defaults to ``WHISPER_HOST`` (or "0.0.0.0") and ``--port`` to
    ``WHISPER_PORT`` (or 8000). Ctrl+C stops the server quietly, and the
    listening socket is always closed on exit.
    """
    parser = argparse.ArgumentParser(description="Faster-Whisper HTTP API for Orphion.")
    parser.add_argument(
        "--host",
        default=os.getenv("WHISPER_HOST", "0.0.0.0"),
        help="interface to bind (env: WHISPER_HOST)",
    )
    # A string default is passed through ``type`` by argparse, so a malformed
    # WHISPER_PORT yields a normal usage error instead of a raw traceback.
    parser.add_argument(
        "--port",
        default=os.getenv("WHISPER_PORT", "8000"),
        type=int,
        help="TCP port to listen on (env: WHISPER_PORT)",
    )
    args = parser.parse_args(argv)

    server = ThreadingHTTPServer((args.host, args.port), WhisperHandler)
    print(f"Whisper API listening on http://{args.host}:{args.port}", flush=True)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass  # Ctrl+C is the normal way to stop the service.
    finally:
        server.server_close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|