feat: add WhisperX diarization and speaker transcript UI

- add WhisperX diarization support to the Whisper VM server
- normalize speaker timestamp segments from Whisper responses
- document Hugging Face/pyannote VM setup and health checks
- show diarized speaker transcript blocks in record and transcript views
- group consecutive segments from the same speaker
- remove duplicate paragraph transcript display when diarized segments exist
- let diarized transcript content expand without an inner scrollbar
This commit is contained in:
KevinB-T 2026-05-20 16:34:50 +05:30
parent 3abb5e9281
commit 9517bad3dc
10 changed files with 674 additions and 73 deletions

View File

@ -6,6 +6,15 @@ import { env } from "./env.js";
import { pool, query } from "./database.js";
const __dirname = path.dirname(fileURLToPath(import.meta.url));
const requiredTables = [
"users",
"refresh_tokens",
"audio_assets",
"transcripts",
"transcript_shares",
"audio_metadata",
"transcription_jobs",
];
function identifier(value) {
return `\`${String(value).replaceAll("`", "``")}\``;
@ -38,12 +47,24 @@ export async function ensureDatabase() {
}
export async function ensureSchema() {
await ensureDatabase();
const schema = await fs.readFile(path.join(__dirname, "schema.sql"), "utf8");
for (const statement of splitSql(schema)) {
await pool.query(statement);
try {
await ensureDatabase();
} catch (error) {
if (!isReadOnlyError(error)) throw error;
await ensureExistingSchema();
return;
}
const schema = await fs.readFile(path.join(__dirname, "schema.sql"), "utf8");
try {
for (const statement of splitSql(schema)) {
await pool.query(statement);
}
await runMigrations();
} catch (error) {
if (!isReadOnlyError(error)) throw error;
await ensureExistingSchema();
}
await runMigrations();
}
export async function runMigrations() {
@ -85,6 +106,31 @@ export async function runMigrations() {
}
}
function isReadOnlyError(error) {
return (
error?.code === "ER_OPTION_PREVENTS_STATEMENT" ||
error?.errno === 1290 ||
String(error?.message ?? "").includes("--read-only")
);
}
async function ensureExistingSchema() {
const rows = await query(
`SELECT TABLE_NAME AS tableName
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME IN (${requiredTables.map((_, index) => `:table${index}`).join(", ")})`,
Object.fromEntries(requiredTables.map((table, index) => [`table${index}`, table])),
);
const present = new Set(rows.map((row) => row.tableName));
const missing = requiredTables.filter((table) => !present.has(table));
if (missing.length > 0) {
throw new Error(
`Database is read-only and schema is incomplete. Missing tables: ${missing.join(", ")}`,
);
}
}
if (import.meta.url === `file://${process.argv[1]}`) {
ensureSchema()
.then(() => {

View File

@ -9,7 +9,7 @@ export function securityMiddleware(app) {
app.set("trust proxy", 1);
}
const origins = env.clientOrigin.split(",").map((origin) => origin.trim());
const origins = allowedOrigins();
app.use(
helmet({
crossOriginResourcePolicy: { policy: "cross-origin" },
@ -35,6 +35,33 @@ export function securityMiddleware(app) {
);
}
function allowedOrigins() {
const configured = env.clientOrigin
.split(",")
.map((origin) => origin.trim())
.filter(Boolean);
if (env.isProduction) return configured;
const expanded = new Set(configured);
for (const origin of configured) {
try {
const url = new URL(origin);
if (url.protocol === "http:") {
url.protocol = "https:";
expanded.add(url.toString().replace(/\/$/, ""));
} else if (url.protocol === "https:") {
url.protocol = "http:";
expanded.add(url.toString().replace(/\/$/, ""));
}
} catch {
// Ignore invalid origin entries; the CORS check below will reject them.
}
}
return [...expanded];
}
export const authRateLimiter = rateLimit({
windowMs: env.rateLimit.windowMs,
limit: env.rateLimit.authMax,

View File

@ -94,10 +94,19 @@ function normalizeTranscript(payload, elapsedMs) {
throw new AppError("Whisper VM returned an empty transcript", 502, "EMPTY_TRANSCRIPT");
}
const timestamps = normalizeSegments(
payload.timestamps ??
payload.segments ??
payload.speaker_segments ??
payload.result?.timestamps ??
payload.result?.segments ??
[],
);
return {
transcriptText,
language: payload.language ?? payload.result?.language ?? null,
timestamps: payload.timestamps ?? payload.segments ?? payload.result?.timestamps ?? [],
timestamps,
duration: numberOrNull(payload.duration ?? payload.result?.duration),
processingTime: Number((elapsedMs / 1000).toFixed(3)),
modelName: env.whisper.modelName,
@ -108,13 +117,38 @@ function mockTranscript(filename) {
return {
transcriptText: `Mock transcript for ${filename}. Configure WHISPER_API_URL to use Faster-Whisper Large v3.`,
language: "en",
timestamps: [{ start: 0, end: 4, text: "Mock transcript generated for local development." }],
timestamps: [
{
speaker: "Speaker 1",
start: 0,
end: 4,
text: "Mock transcript generated for local development.",
},
],
duration: null,
processingTime: 0.15,
modelName: env.whisper.modelName,
};
}
function normalizeSegments(segments) {
if (!Array.isArray(segments)) return [];
return segments
.map((segment) => {
const text = String(segment?.text ?? "").trim();
const start = numberOrNull(segment?.start);
const end = numberOrNull(segment?.end);
if (!text || start === null || end === null) return null;
return {
...(segment.speaker ? { speaker: String(segment.speaker) } : {}),
start,
end,
text,
};
})
.filter(Boolean);
}
function numberOrNull(value) {
const number = Number(value);
return Number.isFinite(number) ? number : null;

View File

@ -16,7 +16,8 @@ WHISPER_ALLOW_MOCK=false
Expected endpoints:
- `GET /health` returns any 2xx status when the VM is ready.
- `GET /health` returns any 2xx status when the VM is ready. For WhisperX diarization, it should
also report `"whisperx": true` and `"diarization": true`.
- `POST /transcribe` accepts multipart audio and returns one of:
```json
@ -24,9 +25,49 @@ Expected endpoints:
"transcript_text": "Meeting transcript...",
"language": "en",
"duration": 123.45,
"timestamps": [{ "start": 0, "end": 5, "text": "Hello" }]
"timestamps": [{ "speaker": "Speaker 1", "start": 0, "end": 5, "text": "Hello" }]
}
```
The API retries failed requests, applies `WHISPER_TIMEOUT_MS`, and marks jobs as failed when the VM
is unavailable.
## Enable WhisperX diarization on the VM
The systemd unit runs `/home/cezen/whisper/server.py` inside the existing
`/home/cezen/whisper/venv`. Deploy the updated script without creating a new venv:
```bash
scp scripts/whisper_http_server.py cezen@172.16.10.64:/home/cezen/whisper/server.py
scp scripts/orphion-whisper.service cezen@172.16.10.64:/tmp/orphion-whisper.service
ssh cezen@172.16.10.64 'sudo mv /tmp/orphion-whisper.service /etc/systemd/system/orphion-whisper.service'
```
Create `/home/cezen/whisper/.env` on the VM with the HuggingFace token accepted by pyannote:
```env
HUGGINGFACE_TOKEN=your_token_here
WHISPERX_DIARIZATION=true
WHISPERX_DEVICE=cuda
WHISPERX_COMPUTE_TYPE=float16
WHISPERX_BATCH_SIZE=8
WHISPERX_DIARIZATION_MODEL=pyannote/speaker-diarization-community-1
```
The HuggingFace account behind the token must be approved for the configured pyannote diarization
model. If transcription returns `"diarization": "fallback"` and the service log mentions a gated
repo, visit `https://huggingface.co/pyannote/speaker-diarization-community-1` while signed in to
that account and accept/request access.
Restart and verify:
```bash
ssh cezen@172.16.10.64 'sudo systemctl daemon-reload && sudo systemctl restart orphion-whisper'
curl -sS http://172.16.10.64:8000/health
```
Expected health shape:
```json
{ "status": "ok", "model": "large-v3", "device": "cuda", "whisperx": true, "diarization": true }
```

View File

@ -4,6 +4,7 @@ import { useEffect, useMemo, useRef, useState } from "react";
import { motion } from "framer-motion";
import {
Check,
Clock,
Loader2,
Mic,
Pause,
@ -14,11 +15,12 @@ import {
Sparkles,
Square,
Trash2,
UserRound,
} from "lucide-react";
import { toast } from "sonner";
import { transcribeAudio } from "@/services/audio";
import { transcriptService } from "@/services/transcripts";
import type { Transcript } from "@/services/types";
import type { Transcript, TranscriptSegment } from "@/services/types";
import { AudioPlayer } from "@/components/audio-player";
import { SendTranscriptDialog } from "@/components/send-transcript-dialog";
@ -483,12 +485,42 @@ function RecordPage() {
</Link>
</div>
</div>
<textarea
value={transcript}
onChange={(event) => setTranscript(event.target.value)}
rows={16}
className="mt-5 w-full resize-none rounded-xl border border-border bg-background/35 p-4 text-sm leading-relaxed outline-none transition focus:border-primary/60"
/>
{currentTranscript.timestamps.length > 0 ? (
<div className="mt-5 space-y-3 rounded-xl border border-border bg-background/35 p-4">
{groupTranscriptSegments(currentTranscript.timestamps).map((segment, index) => {
const speaker = segment.speaker ?? "Speaker";
return (
<div
key={`${segment.start}-${segment.end}-${index}`}
className="rounded-lg border border-border bg-secondary/20 p-3 text-sm"
>
<div className="mb-2 flex flex-wrap items-center gap-2">
<span
className={`inline-flex items-center gap-1.5 rounded-full border px-2.5 py-1 text-xs font-semibold ${speakerTone(
speaker,
)}`}
>
<UserRound className="h-3.5 w-3.5" />
{speaker}
</span>
<span className="inline-flex items-center gap-1.5 font-mono text-xs text-primary">
<Clock className="h-3.5 w-3.5" />
{formatStamp(segment.start)} - {formatStamp(segment.end)}
</span>
</div>
<p className="leading-relaxed text-muted-foreground">{segment.text}</p>
</div>
);
})}
</div>
) : (
<textarea
value={transcript}
onChange={(event) => setTranscript(event.target.value)}
rows={16}
className="mt-5 w-full resize-none rounded-xl border border-border bg-background/35 p-4 text-sm leading-relaxed outline-none transition focus:border-primary/60"
/>
)}
</section>
)}
@ -506,3 +538,43 @@ function RecordPage() {
function formatTime(seconds: number) {
return `${String(Math.floor(seconds / 60)).padStart(2, "0")}:${String(seconds % 60).padStart(2, "0")}`;
}
function speakerTone(speaker: string) {
const tones = [
"border-primary/25 bg-primary/15 text-primary",
"border-emerald-400/25 bg-emerald-400/10 text-emerald-300",
"border-sky-400/25 bg-sky-400/10 text-sky-300",
"border-amber-400/25 bg-amber-400/10 text-amber-300",
"border-rose-400/25 bg-rose-400/10 text-rose-300",
];
const number = Number(speaker.match(/\d+/)?.[0] ?? 1);
return tones[(Math.max(number, 1) - 1) % tones.length];
}
function groupTranscriptSegments(segments: TranscriptSegment[]) {
return segments.reduce<TranscriptSegment[]>((groups, segment) => {
const previous = groups.at(-1);
const speaker = segment.speaker ?? "";
if (previous && (previous.speaker ?? "") === speaker) {
previous.end = segment.end;
previous.text = `${previous.text} ${segment.text}`.trim();
return groups;
}
groups.push({ ...segment });
return groups;
}, []);
}
function formatStamp(seconds: number) {
const safeSeconds = Number.isFinite(seconds) ? Math.max(seconds, 0) : 0;
const totalSeconds = Math.floor(safeSeconds);
const hours = Math.floor(totalSeconds / 3600);
const minutes = Math.floor((totalSeconds % 3600) / 60);
const secs = totalSeconds % 60;
if (hours > 0) {
return `${String(hours).padStart(2, "0")}:${String(minutes).padStart(2, "0")}:${String(
secs,
).padStart(2, "0")}`;
}
return `${String(minutes).padStart(2, "0")}:${String(secs).padStart(2, "0")}`;
}

View File

@ -1,11 +1,22 @@
import { createFileRoute, Link, useNavigate } from "@tanstack/react-router";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef, useState } from "react";
import { ArrowLeft, Copy, Download, Loader2, Save, Send, Sparkles, Trash2 } from "lucide-react";
import {
ArrowLeft,
Clock,
Copy,
Download,
Loader2,
Save,
Send,
Sparkles,
Trash2,
UserRound,
} from "lucide-react";
import { toast } from "sonner";
import { transcriptService } from "@/services/transcripts";
import { useAuth } from "@/context/auth";
import type { Transcript } from "@/services/types";
import type { Transcript, TranscriptSegment } from "@/services/types";
import { AudioPlayer } from "@/components/audio-player";
import { ConfirmDialog } from "@/components/confirm-dialog";
import { SendTranscriptDialog } from "@/components/send-transcript-dialog";
@ -104,7 +115,7 @@ function TranscriptDetailPage() {
}
async function copy() {
await navigator.clipboard.writeText(text);
await navigator.clipboard.writeText(formatTranscriptForClipboard(transcript));
toast.success("Transcript copied");
}
@ -277,34 +288,19 @@ function TranscriptDetailPage() {
</button>
)}
</div>
<textarea
value={text}
onChange={(event) => setText(event.target.value)}
rows={18}
readOnly={isReceived}
className="w-full resize-none bg-transparent text-sm leading-relaxed outline-none read-only:text-muted-foreground"
/>
{transcript.timestamps.length > 0 ? (
<SpeakerTranscript transcript={transcript} />
) : (
<textarea
value={text}
onChange={(event) => setText(event.target.value)}
rows={18}
readOnly={isReceived}
className="w-full resize-none bg-transparent text-sm leading-relaxed outline-none read-only:text-muted-foreground"
/>
)}
</div>
{transcript.timestamps.length > 0 && (
<div className="glass rounded-2xl p-6">
<h2 className="font-display text-lg font-semibold">Timestamps</h2>
<div className="mt-4 max-h-72 space-y-2 overflow-y-auto">
{transcript.timestamps.map((segment, index) => (
<div
key={`${segment.start}-${index}`}
className="grid gap-3 rounded-xl border border-border bg-secondary/20 p-3 text-sm md:grid-cols-[120px_1fr]"
>
<span className="font-mono text-xs text-primary">
{formatStamp(segment.start)} - {formatStamp(segment.end)}
</span>
<span className="text-muted-foreground">{segment.text}</span>
</div>
))}
</div>
</div>
)}
{showSend && (
<SendTranscriptDialog sending={sending} onClose={() => setShowSend(false)} onSend={send} />
)}
@ -322,8 +318,105 @@ function TranscriptDetailPage() {
);
}
function formatStamp(seconds: number) {
const minutes = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${minutes}:${String(secs).padStart(2, "0")}`;
function SpeakerTranscript({ transcript }: { transcript: Transcript }) {
const hasSpeakers = transcript.timestamps.some((segment) => segment.speaker);
const groupedSegments = groupTranscriptSegments(transcript.timestamps);
return (
<div>
<div className="flex flex-wrap items-center justify-between gap-3">
<h2 className="font-display text-lg font-semibold">
{hasSpeakers ? "Speaker transcript" : "Timestamps"}
</h2>
{hasSpeakers && (
<span className="rounded-full border border-primary/20 bg-primary/10 px-3 py-1 text-xs font-medium text-primary">
{speakerCount(transcript)} {speakerCount(transcript) === 1 ? "speaker" : "speakers"}
</span>
)}
</div>
<div className="mt-4 space-y-3">
{groupedSegments.map((segment, index) => {
const speaker = segment.speaker ?? "Speaker";
return (
<div
key={`${segment.start}-${segment.end}-${index}`}
className="rounded-xl border border-border bg-secondary/20 p-4 text-sm"
>
<div className="mb-2 flex flex-wrap items-center gap-2">
{segment.speaker && (
<span
className={`inline-flex items-center gap-1.5 rounded-full border px-2.5 py-1 text-xs font-semibold ${speakerTone(
speaker,
)}`}
>
<UserRound className="h-3.5 w-3.5" />
{speaker}
</span>
)}
<span className="inline-flex items-center gap-1.5 font-mono text-xs text-primary">
<Clock className="h-3.5 w-3.5" />
{formatStamp(segment.start)} - {formatStamp(segment.end)}
</span>
</div>
<p className="leading-relaxed text-muted-foreground">{segment.text}</p>
</div>
);
})}
</div>
</div>
);
}
function formatTranscriptForClipboard(transcript: Transcript | null) {
if (!transcript?.timestamps.length) return transcript?.transcriptText ?? "";
return groupTranscriptSegments(transcript.timestamps)
.map((segment) => {
const speaker = segment.speaker ? `${segment.speaker}\n` : "";
return `${speaker}${formatStamp(segment.start)} - ${formatStamp(segment.end)}\n${segment.text}`;
})
.join("\n\n");
}
function speakerCount(transcript: Transcript) {
return new Set(transcript.timestamps.map((segment) => segment.speaker).filter(Boolean)).size;
}
function groupTranscriptSegments(segments: TranscriptSegment[]) {
return segments.reduce<TranscriptSegment[]>((groups, segment) => {
const previous = groups.at(-1);
const speaker = segment.speaker ?? "";
if (previous && (previous.speaker ?? "") === speaker) {
previous.end = segment.end;
previous.text = `${previous.text} ${segment.text}`.trim();
return groups;
}
groups.push({ ...segment });
return groups;
}, []);
}
function speakerTone(speaker: string) {
const tones = [
"border-primary/25 bg-primary/15 text-primary",
"border-emerald-400/25 bg-emerald-400/10 text-emerald-300",
"border-sky-400/25 bg-sky-400/10 text-sky-300",
"border-amber-400/25 bg-amber-400/10 text-amber-300",
"border-rose-400/25 bg-rose-400/10 text-rose-300",
];
const number = Number(speaker.match(/\d+/)?.[0] ?? 1);
return tones[(Math.max(number, 1) - 1) % tones.length];
}
function formatStamp(seconds: number) {
const safeSeconds = Number.isFinite(seconds) ? Math.max(seconds, 0) : 0;
const totalSeconds = Math.floor(safeSeconds);
const hours = Math.floor(totalSeconds / 3600);
const minutes = Math.floor((totalSeconds % 3600) / 60);
const secs = totalSeconds % 60;
if (hours > 0) {
return `${String(hours).padStart(2, "0")}:${String(minutes).padStart(2, "0")}:${String(
secs,
).padStart(2, "0")}`;
}
return `${String(minutes).padStart(2, "0")}:${String(secs).padStart(2, "0")}`;
}

View File

@ -10,6 +10,7 @@ export type User = {
export type UserSummary = Pick<User, "id" | "fullName" | "username" | "email">;
export type TranscriptSegment = {
speaker?: string;
start: number;
end: number;
text: string;

View File

@ -1,3 +1,7 @@
import { execFileSync } from "node:child_process";
import { existsSync, mkdirSync, readFileSync } from "node:fs";
import { tmpdir } from "node:os";
import path from "node:path";
import { fileURLToPath, URL } from "node:url";
import tailwindcss from "@tailwindcss/vite";
import { tanstackRouter } from "@tanstack/router-plugin/vite";
@ -5,11 +9,49 @@ import viteReact from "@vitejs/plugin-react";
import { defineConfig, loadEnv } from "vite";
import tsConfigPaths from "vite-tsconfig-paths";
export default defineConfig(({ mode }) => {
function getDevHttpsOptions() {
const certDir = path.join(tmpdir(), "orphion-vite-https");
const keyPath = path.join(certDir, "localhost-key.pem");
const certPath = path.join(certDir, "localhost-cert.pem");
if (!existsSync(keyPath) || !existsSync(certPath)) {
mkdirSync(certDir, { recursive: true });
execFileSync(
"openssl",
[
"req",
"-x509",
"-newkey",
"rsa:2048",
"-nodes",
"-sha256",
"-days",
"365",
"-subj",
"/CN=localhost",
"-addext",
"subjectAltName=DNS:localhost,IP:127.0.0.1,IP:0.0.0.0",
"-keyout",
keyPath,
"-out",
certPath,
],
{ stdio: "ignore" },
);
}
return {
key: readFileSync(keyPath),
cert: readFileSync(certPath),
};
}
export default defineConfig(({ command, mode }) => {
const viteEnv = loadEnv(mode, process.cwd(), "");
const serviceHost = viteEnv.VITE_ORPHION_SERVICE_HOST || "127.0.0.1";
const apiPort = viteEnv.VITE_API_PORT || "4000";
const apiProxyTarget = viteEnv.VITE_API_PROXY_TARGET || `http://${serviceHost}:${apiPort}`;
const https = command === "serve" ? getDevHttpsOptions() : undefined;
return {
plugins: [
@ -30,6 +72,7 @@ export default defineConfig(({ mode }) => {
dedupe: ["react", "react-dom", "@tanstack/react-query", "@tanstack/query-core"],
},
server: {
...(https ? { https } : {}),
proxy: {
"/api": {
target: apiProxyTarget,
@ -37,6 +80,7 @@ export default defineConfig(({ mode }) => {
},
},
},
...(https ? { preview: { https } } : {}),
build: {
rollupOptions: {
output: {

View File

@ -8,9 +8,15 @@ Type=simple
User=cezen
Group=cezen
WorkingDirectory=/home/cezen/whisper
EnvironmentFile=-/home/cezen/whisper/.env
Environment=WHISPER_MODEL=large-v3
Environment=WHISPER_MODEL_DIR=/home/cezen/whisper/models
Environment=WHISPER_FILE_FIELD=file
Environment=WHISPERX_DIARIZATION=true
Environment=WHISPERX_DEVICE=cuda
Environment=WHISPERX_COMPUTE_TYPE=float16
Environment=WHISPERX_BATCH_SIZE=8
Environment=WHISPERX_DIARIZATION_MODEL=pyannote/speaker-diarization-community-1
ExecStart=/home/cezen/whisper/venv/bin/python /home/cezen/whisper/server.py --host 0.0.0.0 --port 8000
Restart=always
RestartSec=5

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Tiny Faster-Whisper HTTP API for Orphion.
"""Tiny Faster-Whisper/WhisperX HTTP API for Orphion.
Endpoints:
GET /health
@ -10,6 +10,8 @@ from __future__ import annotations
import argparse
import cgi
import gc
import inspect
import json
import os
import tempfile
@ -21,6 +23,21 @@ from pathlib import Path
from faster_whisper import WhisperModel
try:
import torch
except Exception: # pragma: no cover - torch is optional for CPU fallback
torch = None
try:
import whisperx
except Exception: # pragma: no cover - keeps existing Faster-Whisper path available
whisperx = None
try:
import whisperx.diarize as whisperx_diarize
except Exception: # pragma: no cover - older/newer WhisperX layouts vary
whisperx_diarize = None
try:
import ctranslate2
except Exception: # pragma: no cover - only used for runtime device selection
@ -29,6 +46,11 @@ except Exception: # pragma: no cover - only used for runtime device selection
MODEL_LOCK = threading.Lock()
MODEL = None
WHISPERX_LOCK = threading.Lock()
INFERENCE_LOCK = threading.Lock()
WHISPERX_MODEL = None
ALIGN_MODELS = {}
DIARIZATION_PIPELINE = None
def choose_device() -> str:
@ -50,7 +72,35 @@ def default_compute_type(device: str) -> str:
configured = os.getenv("WHISPER_COMPUTE_TYPE", "").strip()
if configured:
return configured
return "float16" if device == "cuda" else "int8"
return "float16" if device.startswith("cuda") else "int8"
def diarization_enabled() -> bool:
return os.getenv("WHISPERX_DIARIZATION", "true").lower() in {"1", "true", "yes", "on"}
def whisperx_device() -> str:
configured = os.getenv("WHISPERX_DEVICE", "").strip()
if configured:
return configured
selected = choose_device()
return "cuda" if selected == "cuda" else selected
def whisperx_compute_type(device: str) -> str:
configured = os.getenv("WHISPERX_COMPUTE_TYPE", "").strip()
if configured:
return configured
return "float16" if device.startswith("cuda") else default_compute_type(device)
def huggingface_token() -> str | None:
return (
os.getenv("HUGGINGFACE_TOKEN")
or os.getenv("HF_TOKEN")
or os.getenv("PYANNOTE_AUTH_TOKEN")
or None
)
def get_model() -> WhisperModel:
@ -71,6 +121,196 @@ def get_model() -> WhisperModel:
return MODEL
def get_whisperx_model():
global WHISPERX_MODEL
if WHISPERX_MODEL is not None:
return WHISPERX_MODEL
if whisperx is None:
raise RuntimeError("whisperx is not installed")
with WHISPERX_LOCK:
if WHISPERX_MODEL is None:
device = whisperx_device()
model_name = os.getenv("WHISPERX_MODEL", os.getenv("WHISPER_MODEL", "large-v3"))
kwargs = {
"compute_type": whisperx_compute_type(device),
"download_root": os.getenv("WHISPER_MODEL_DIR") or None,
}
try:
WHISPERX_MODEL = whisperx.load_model(model_name, device, **kwargs)
except TypeError:
kwargs.pop("download_root", None)
WHISPERX_MODEL = whisperx.load_model(model_name, device, **kwargs)
return WHISPERX_MODEL
def get_align_model(language_code: str | None, device: str):
if whisperx is None:
raise RuntimeError("whisperx is not installed")
language = language_code or os.getenv("WHISPERX_ALIGN_LANGUAGE", "en")
key = (language, device)
if key not in ALIGN_MODELS:
ALIGN_MODELS[key] = whisperx.load_align_model(language_code=language, device=device)
return ALIGN_MODELS[key]
def get_diarization_pipeline(device: str):
global DIARIZATION_PIPELINE
if whisperx is None:
raise RuntimeError("whisperx is not installed")
if DIARIZATION_PIPELINE is not None:
return DIARIZATION_PIPELINE
token = huggingface_token()
if not token:
raise RuntimeError("HuggingFace token is required for speaker diarization")
with WHISPERX_LOCK:
if DIARIZATION_PIPELINE is None:
pipeline_factory = getattr(whisperx, "DiarizationPipeline", None)
if pipeline_factory is None and whisperx_diarize is not None:
pipeline_factory = getattr(whisperx_diarize, "DiarizationPipeline", None)
if pipeline_factory is None:
raise RuntimeError("WhisperX diarization pipeline is not available")
parameters = inspect.signature(pipeline_factory).parameters
if "use_auth_token" in parameters:
DIARIZATION_PIPELINE = pipeline_factory(use_auth_token=token, device=device)
elif "auth_token" in parameters:
DIARIZATION_PIPELINE = pipeline_factory(auth_token=token, device=device)
elif "token" in parameters:
DIARIZATION_PIPELINE = pipeline_factory(
model_name=os.getenv("WHISPERX_DIARIZATION_MODEL") or None,
token=token,
device=device,
)
else:
DIARIZATION_PIPELINE = pipeline_factory(device=device)
return DIARIZATION_PIPELINE
def cleanup_gpu_memory() -> None:
gc.collect()
if torch is not None and getattr(torch, "cuda", None) is not None:
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
def normalize_speaker(raw: str | None, mapping: dict[str, str]) -> str:
if not raw:
return "Speaker 1"
key = str(raw)
if key not in mapping:
mapping[key] = f"Speaker {len(mapping) + 1}"
return mapping[key]
def segment_speaker(segment: dict) -> str | None:
if segment.get("speaker"):
return segment.get("speaker")
counts: dict[str, int] = {}
for word in segment.get("words") or []:
speaker = word.get("speaker")
if speaker:
counts[speaker] = counts.get(speaker, 0) + 1
if not counts:
return None
return max(counts.items(), key=lambda item: item[1])[0]
def normalize_segments(segments: list[dict], include_speakers: bool = True) -> list[dict]:
speaker_map: dict[str, str] = {}
normalized = []
for segment in segments:
text = str(segment.get("text") or "").strip()
if not text:
continue
item = {
"start": round(float(segment.get("start") or 0), 3),
"end": round(float(segment.get("end") or 0), 3),
"text": text,
}
if include_speakers:
item["speaker"] = normalize_speaker(segment_speaker(segment), speaker_map)
normalized.append(item)
return normalized
def transcribe_with_whisperx(audio_path: str) -> dict:
device = whisperx_device()
model = get_whisperx_model()
audio = whisperx.load_audio(audio_path)
batch_size = int(os.getenv("WHISPERX_BATCH_SIZE", os.getenv("WHISPER_BATCH_SIZE", "8")))
result = model.transcribe(audio, batch_size=batch_size)
language = result.get("language")
try:
align_model, metadata = get_align_model(language, device)
result = whisperx.align(
result.get("segments", []),
align_model,
metadata,
audio,
device,
return_char_alignments=False,
)
except Exception:
traceback.print_exc()
diarization_failed = False
try:
diarize_model = get_diarization_pipeline(device)
min_speakers = os.getenv("WHISPERX_MIN_SPEAKERS")
max_speakers = os.getenv("WHISPERX_MAX_SPEAKERS")
kwargs = {}
if min_speakers:
kwargs["min_speakers"] = int(min_speakers)
if max_speakers:
kwargs["max_speakers"] = int(max_speakers)
diarize_segments = diarize_model(audio, **kwargs)
result = whisperx.assign_word_speakers(diarize_segments, result)
except Exception:
diarization_failed = True
traceback.print_exc()
timestamps = normalize_segments(result.get("segments", []), include_speakers=not diarization_failed)
transcript_text = " ".join(segment["text"] for segment in timestamps if segment.get("text"))
return {
"transcript_text": transcript_text,
"language": language,
"duration": float(audio.shape[0]) / 16000 if hasattr(audio, "shape") else None,
"timestamps": timestamps,
"diarization": "fallback" if diarization_failed else "completed",
}
def transcribe_with_faster_whisper(audio_path: str) -> dict:
model = get_model()
segments, info = model.transcribe(
audio_path,
beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")),
vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"},
)
timestamps = []
transcript_parts = []
for segment in segments:
text = segment.text.strip()
transcript_parts.append(text)
timestamps.append({"start": segment.start, "end": segment.end, "text": text})
return {
"transcript_text": " ".join(part for part in transcript_parts if part),
"language": getattr(info, "language", None),
"duration": getattr(info, "duration", None),
"timestamps": timestamps,
"diarization": "disabled",
}
class WhisperHandler(BaseHTTPRequestHandler):
server_version = "OrphionWhisper/1.0"
@ -84,6 +324,11 @@ class WhisperHandler(BaseHTTPRequestHandler):
"status": "ok",
"model": os.getenv("WHISPER_MODEL", "large-v3"),
"device": choose_device(),
"whisperx": whisperx is not None,
"diarization": diarization_enabled(),
"diarization_ready": whisperx is not None
and diarization_enabled()
and bool(huggingface_token()),
},
)
@ -129,28 +374,20 @@ class WhisperHandler(BaseHTTPRequestHandler):
temp_file.write(chunk)
try:
model = get_model()
segments, info = model.transcribe(
temp_path,
beam_size=int(os.getenv("WHISPER_BEAM_SIZE", "5")),
vad_filter=os.getenv("WHISPER_VAD_FILTER", "true").lower() in {"1", "true", "yes", "on"},
)
timestamps = []
transcript_parts = []
for segment in segments:
text = segment.text.strip()
transcript_parts.append(text)
timestamps.append({"start": segment.start, "end": segment.end, "text": text})
with INFERENCE_LOCK:
try:
if diarization_enabled():
payload = transcribe_with_whisperx(temp_path)
else:
payload = transcribe_with_faster_whisper(temp_path)
except Exception:
traceback.print_exc()
payload = transcribe_with_faster_whisper(temp_path)
payload["diarization"] = "fallback"
self.send_json(
{
"transcript_text": " ".join(part for part in transcript_parts if part),
"language": getattr(info, "language", None),
"duration": getattr(info, "duration", None),
"timestamps": timestamps,
},
)
self.send_json(payload)
finally:
cleanup_gpu_memory()
try:
os.unlink(temp_path)
except OSError: