Files
audiobook-maker-pro-v4.2/beam_app/app.py
2026-05-22 18:28:47 +06:00

332 lines
11 KiB
Python

# beam_app/app.py - Combined TTS + Timestamp Endpoint (Fixed)
import os
import re
import base64
from io import BytesIO
from dataclasses import dataclass
from beam import endpoint, Image
# ====================================================
# Container Image
# ====================================================
# The image is assembled one builder step at a time; each call's return value
# feeds the next step, exactly as a chained expression would.
gpu_image = Image(python_version="python3.11")
# Python dependencies: torch/torchaudio (tensors, resampling, wav2vec2 aligner),
# kokoro (TTS) and soundfile (WAV encoding). transformers, scipy and
# huggingface_hub are listed explicitly too — presumably for Kokoro's runtime /
# model download; confirm before removing any of them.
gpu_image = gpu_image.add_python_packages(
    [
        "torch>=2.4.0",
        "torchaudio>=2.4.0",
        "transformers",
        "scipy",
        "soundfile>=0.12.0",
        "kokoro>=0.1.0",
        "huggingface_hub",
    ]
)
# System packages: libsndfile1 is the native library behind `soundfile`;
# ffmpeg is installed alongside it (no direct user is visible in this file).
gpu_image = gpu_image.add_commands(
    ["apt-get update && apt-get install -y libsndfile1 ffmpeg"]
)
# ====================================================
# Voice IDs
# ====================================================
# Voice names the endpoint accepts for the request's "voice" field.
# Names follow Kokoro's "<language letter><f|m>_<name>" convention and are
# grouped below by prefix; the resulting list keeps the original order.
# NOTE(review): load_all_models builds one lang_code='a' pipeline, so voices
# with other prefixes are spoken through that pipeline — confirm intended.
VOICE_IDS = [
    voice_id
    for group in (
        # a — American English
        "af_alloy af_aoede af_bella af_heart af_jessica af_kore "
        "af_nicole af_nova af_river af_sarah af_sky",
        "am_adam am_echo am_eric am_fenrir am_liam "
        "am_michael am_onyx am_puck am_santa",
        # b — British English
        "bf_alice bf_emma bf_isabella bf_lily",
        "bm_daniel bm_fable bm_george bm_lewis",
        # e — Spanish
        "ef_dora em_alex em_santa",
        # f — French
        "ff_siwis",
        # h — Hindi
        "hf_alpha hf_beta hm_omega hm_psi",
        # i — Italian
        "if_sara im_nicola",
        # j — Japanese
        "jf_alpha jf_gongitsune jf_nezumi jf_tebukuro jm_kumo",
        # p — Brazilian Portuguese
        "pf_dora pm_alex pm_santa",
        # z — Mandarin Chinese
        "zf_xiaobei zf_xiaoni zf_xiaoxiao zf_xiaoyi",
        "zm_yunjian zm_yunxi zm_yunxia zm_yunyang",
    )
    for voice_id in group.split()
]
@dataclass
class TokenSpan:
    """One run of consecutive alignment frames that carry the same token id.

    The frame range is half-open: the run covers frame ``start`` up to, but
    not including, frame ``end``.
    """

    start: int  # first frame of the run (inclusive)
    end: int  # frame just past the run (exclusive)
    token: int  # vocabulary index shared by every frame in the run
# ====================================================
# on_start — Model Loading
# ====================================================
def _load_tts_pipeline(device):
    """Create the Kokoro ``KPipeline`` (lang_code 'a') on *device*.

    Raises:
        Exception: whatever Kokoro raised, re-raised unchanged after logging —
        the endpoint cannot serve any request without TTS.
    """
    import traceback

    print("📂 [1/2] Loading Kokoro TTS...")
    try:
        from kokoro import KPipeline

        pipeline = KPipeline(lang_code='a', device=device)
    except Exception as e:
        print(f"❌ [1/2] Kokoro TTS FAILED: {e}")
        traceback.print_exc()
        # TTS is mandatory. A bare `raise` propagates the original exception
        # and traceback instead of re-raising from this line.
        raise
    print("✅ [1/2] Kokoro TTS loaded")
    return pipeline


def _load_aligner(torch_device):
    """Load the torchaudio wav2vec2 CTC model used for forced alignment.

    Best effort: a failure is logged and reported as four ``None`` values so
    the container still starts and serves audio without word timestamps.

    Returns:
        ``(bundle, aligner_model, labels, dictionary)`` where ``dictionary``
        maps each label to its index in ``labels``.
    """
    import traceback

    import torchaudio

    print("📂 [2/2] Loading wav2vec2 Aligner...")
    try:
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        aligner_model = bundle.get_model().to(torch_device)
        labels = bundle.get_labels()
    except Exception as e:
        print(f"⚠️ [2/2] wav2vec2 FAILED: {e} — will skip alignment")
        traceback.print_exc()
        return None, None, None, None
    print("✅ [2/2] wav2vec2 Aligner loaded")
    return bundle, aligner_model, labels, {c: i for i, c in enumerate(labels)}


def load_all_models():
    """Load every model once when the container starts (Beam ``on_start`` hook).

    Returns:
        dict with keys ``pipeline`` (Kokoro pipeline), ``device`` ('cuda' when
        available, else 'cpu'), ``torch_device``, and the aligner entries
        ``aligner_model``, ``bundle``, ``labels`` and ``dictionary``. The
        aligner entries are ``None`` when wav2vec2 failed to load.

    Raises:
        Exception: propagated from Kokoro when the TTS pipeline cannot load.
    """
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch_device = torch.device(device)
    banner = '=' * 50
    print(banner)
    print(f"🚀 CONTAINER STARTING — Device: {device}")
    print(banner)

    # 1. Kokoro TTS — mandatory; a failure aborts container start-up.
    pipeline = _load_tts_pipeline(device)
    # 2. wav2vec2 aligner — optional; a failure only disables timestamps.
    bundle, aligner_model, labels, dictionary = _load_aligner(torch_device)

    print(banner)
    print("🎉 CONTAINER READY")
    print(banner)
    return {
        "pipeline": pipeline,
        "device": device,
        "torch_device": torch_device,
        "aligner_model": aligner_model,
        "bundle": bundle,
        "labels": labels,
        "dictionary": dictionary,
    }
# ====================================================
# Combined Endpoint
# ====================================================
# wav2vec2 is run over the resampled audio in windows of this length rather
# than in one forward pass, so cost and memory stay bounded for long passages.
_ALIGN_CHUNK_SECONDS = 30
# A trailing window shorter than this is merged into the previous one so no
# window is too short for the model's convolutional front-end.
_ALIGN_MIN_TAIL_SECONDS = 1


def _as_flag(value):
    """Return the truth value of a request flag.

    Strings are parsed ("1"/"true"/"yes"/"on", case-insensitive, are true)
    because a JSON string such as "false" would otherwise be truthy.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


def _build_alignment_targets(text, dictionary):
    """Turn *text* into the character tokens the CTC aligner is forced onto.

    Each whitespace-separated word is stripped to ASCII letters and digits
    and upper-cased; words that become empty are dropped.

    Returns:
        ``(original_words, valid_indices, clean_words, token_indices)`` where
        ``valid_indices[i]`` is the index in ``original_words`` of
        ``clean_words[i]`` and ``token_indices`` holds one vocabulary index per
        character of ``" ".join(clean_words)``. Spaces — and characters missing
        from the vocabulary (e.g. digits, if the model has none) — map to the
        ``'|'`` word-separator token.
    """
    separator = dictionary.get('|', 0)
    original_words = text.split()
    clean_words = []
    valid_indices = []
    for idx, word in enumerate(original_words):
        clean = re.sub(r'[^a-zA-Z0-9]', '', word).upper()
        if clean:
            clean_words.append(clean)
            valid_indices.append(idx)
    transcript = " ".join(clean_words)
    token_indices = [
        separator if char == ' ' else dictionary.get(char, separator)
        for char in transcript
    ]
    return original_words, valid_indices, clean_words, token_indices


def _compute_log_probs(aligner_model, waveform, chunk_samples, min_tail_samples):
    """Run *aligner_model* window by window and return log-softmax emissions.

    Args:
        waveform: ``(1, samples)`` tensor at the aligner's sample rate.
        chunk_samples: window length in samples.
        min_tail_samples: a final window shorter than this is folded into the
            previous window.

    Returns:
        Emissions concatenated along the frame axis (dim 1), log-softmaxed
        over the vocabulary axis.
    """
    import torch

    total = waveform.size(1)
    starts = list(range(0, total, chunk_samples)) or [0]
    if len(starts) > 1 and total - starts[-1] < min_tail_samples:
        starts.pop()  # fold the short tail into the previous window
    ends = starts[1:] + [total]
    with torch.inference_mode():
        pieces = []
        for start, end in zip(starts, ends):
            emission, _ = aligner_model(waveform[:, start:end])
            pieces.append(emission)
        return torch.log_softmax(torch.cat(pieces, dim=1), dim=-1)


def _path_to_segments(path, blank=0):
    """Collapse a per-frame token path into runs of non-blank tokens.

    Consecutive frames with the same token form one ``TokenSpan`` with a
    half-open ``[start, end)`` frame range; runs of *blank* are dropped.
    """
    import itertools

    segments = []
    frame = 0
    for label, run in itertools.groupby(path):
        run_length = sum(1 for _ in run)
        if label != blank:
            segments.append(TokenSpan(start=frame, end=frame + run_length, token=label))
        frame += run_length
    return segments


def _segments_to_word_timestamps(segments, clean_words, original_words,
                                 valid_indices, separator, seconds_per_frame):
    """Map token spans back to words and convert frames to seconds.

    Each clean word consumes ``len(word)`` consecutive spans (one per
    character); a following ``separator`` span between words is skipped.
    Stops early when the spans run out. Times are rounded to 0.01 s.
    """
    timestamps = []
    segment_idx = 0
    for word_str, original_idx in zip(clean_words, valid_indices):
        word_len = len(word_str)
        if segment_idx + word_len > len(segments):
            break
        first = segments[segment_idx]
        last = segments[segment_idx + word_len - 1]
        timestamps.append({
            "word": original_words[original_idx],
            "start": round(first.start * seconds_per_frame, 2),
            "end": round(last.end * seconds_per_frame, 2),
        })
        segment_idx += word_len
        if segment_idx < len(segments) and segments[segment_idx].token == separator:
            segment_idx += 1
    return timestamps


@endpoint(
    name="tts-combined",
    image=gpu_image,
    on_start=load_all_models,
    gpu="RTX4090",
    cpu=2,
    memory="16Gi",
    keep_warm_seconds=180,
)
def generate_audio_with_timestamps(context, **inputs):
    """Synthesize speech and, optionally, word timestamps on the same GPU.

    Request inputs:
        text (str): text to speak; at least 2 non-whitespace characters.
        voice (str): one of ``VOICE_IDS``; defaults to "af_heart".
        speed (float): clamped to [0.5, 2.0]; unparsable values become 1.0.
        skip_alignment (bool): skip forced alignment when true.

    Returns:
        On success: ``success=True``, the base64 WAV (``audio_base64``,
        ``audio_format``, ``sample_rate``), the effective ``voice`` and
        ``speed``, ``text_length``, ``timestamps`` (a list of
        ``{"word", "start", "end"}`` dicts in seconds; empty when alignment
        is skipped, unavailable or fails) and ``word_count`` (its length).
        On failure: ``{"error": <message>, "success": False}``.
    """
    import traceback

    import soundfile as sf
    import torch
    import torchaudio

    print("")
    print("📥 REQUEST RECEIVED")
    print(f" Keys: {list(inputs.keys())}")

    # A missing/None text counts as empty; any other non-string is rejected
    # up front instead of crashing in len()/split() further down.
    text = inputs.get("text") or ""
    if not isinstance(text, str):
        return {"error": "Text must be a string", "success": False}
    print(f" text length: {len(text)}")

    # ---- Get models from on_start ----
    ctx = context.on_start_value
    if ctx is None:
        print("❌ on_start_value is None!")
        return {"error": "Models not loaded", "success": False}
    pipeline = ctx["pipeline"]
    aligner_model = ctx.get("aligner_model")
    bundle = ctx.get("bundle")
    dictionary = ctx.get("dictionary")
    torch_device = ctx["torch_device"]

    # ---- Parse inputs ----
    voice = inputs.get("voice", "af_heart")
    speed = inputs.get("speed", 1.0)
    skip_alignment = _as_flag(inputs.get("skip_alignment", False))
    if len(text.strip()) < 2:
        return {"error": "Text is required (min 2 chars)", "success": False}
    if voice not in VOICE_IDS:
        return {"error": f"Invalid voice '{voice}'", "success": False}
    try:
        speed = max(0.5, min(2.0, float(speed)))
    except (TypeError, ValueError):
        speed = 1.0

    # =================================================
    # STEP 1: TTS Generation
    # =================================================
    sample_rate = 24000  # output rate this endpoint assumes for Kokoro audio
    try:
        print(f"🔊 TTS: voice={voice}, speed={speed}, text={len(text)} chars")
        print(f" Preview: {text[:80]}...")
        generator = pipeline(text, voice=voice, speed=speed, split_pattern=r'\n+')
        # Each item unpacks to (graphemes, phonemes, audio); keep the audio.
        all_audio = [audio for _, _, audio in generator if audio is not None]
        if not all_audio:
            print("❌ No audio chunks generated")
            return {"error": "No audio generated", "success": False}
        full_audio = torch.cat(all_audio, dim=0)

        audio_buffer = BytesIO()
        sf.write(audio_buffer, full_audio.cpu().numpy(), sample_rate, format='WAV')
        audio_bytes = audio_buffer.getvalue()
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        print(f"✅ TTS done: {len(audio_base64)} bytes base64, {len(audio_bytes)} bytes raw")
    except Exception as e:
        print(f"❌ TTS FAILED: {e}")
        traceback.print_exc()
        return {"error": f"TTS failed: {str(e)}", "success": False}

    # =================================================
    # STEP 2: Forced Alignment (best effort — audio is returned regardless)
    # =================================================
    timestamps = []
    if skip_alignment:
        print("⏭️ Alignment skipped (skip_alignment=True)")
    elif aligner_model is None or bundle is None or dictionary is None:
        print("⚠️ Alignment skipped (model not loaded)")
    else:
        try:
            print(f"⏳ Aligning {len(text)} chars...")
            # Reuse the synthesized tensor (no WAV decode round-trip), shaped
            # (1, samples) and moved to the aligner's device.
            waveform = full_audio.unsqueeze(0).float().to(torch_device)
            if sample_rate != bundle.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=bundle.sample_rate
                ).to(torch_device)
                waveform = resampler(waveform)

            separator = dictionary.get('|', 0)
            original_words, valid_indices, clean_words, token_indices = (
                _build_alignment_targets(text, dictionary)
            )
            if token_indices:
                emissions = _compute_log_probs(
                    aligner_model,
                    waveform,
                    chunk_samples=int(_ALIGN_CHUNK_SECONDS * bundle.sample_rate),
                    min_tail_samples=int(_ALIGN_MIN_TAIL_SECONDS * bundle.sample_rate),
                )
                targets = torch.tensor(
                    token_indices, dtype=torch.int32, device=torch_device
                ).unsqueeze(0)
                # Batch size is 1, so the default (full) input/target lengths apply.
                # NOTE(review): torchaudio is only pinned as >=2.4.0 in the image;
                # confirm the installed release still ships functional.forced_align.
                path, _ = torchaudio.functional.forced_align(emissions, targets, blank=0)
                path = path[0].tolist()
                if path:
                    segments = _path_to_segments(path, blank=0)
                    # Seconds of (resampled) audio covered by one emission frame.
                    seconds_per_frame = waveform.size(1) / len(path) / bundle.sample_rate
                    timestamps = _segments_to_word_timestamps(
                        segments,
                        clean_words=clean_words,
                        original_words=original_words,
                        valid_indices=valid_indices,
                        separator=separator,
                        seconds_per_frame=seconds_per_frame,
                    )
            print(f"✅ Aligned {len(timestamps)} words")
        except Exception as e:
            print(f"⚠️ Alignment failed (audio still valid): {e}")
            traceback.print_exc()
            timestamps = []

    # =================================================
    # Return Result
    # =================================================
    result = {
        "success": True,
        "audio_base64": audio_base64,
        "audio_format": "wav",
        "sample_rate": sample_rate,
        "voice": voice,
        "speed": speed,
        "text_length": len(text),
        "timestamps": timestamps,
        "word_count": len(timestamps),
    }
    print(f"📤 RESPONSE: success=True, audio={len(audio_base64)} bytes, words={len(timestamps)}")
    print("")
    return result