332 lines
11 KiB
Python
332 lines
11 KiB
Python
# beam_app/app.py - Combined TTS + Timestamp Endpoint (Fixed)
|
|
|
|
import os
|
|
import re
|
|
import base64
|
|
from io import BytesIO
|
|
from dataclasses import dataclass
|
|
|
|
from beam import endpoint, Image
|
|
|
|
# ====================================================
|
|
# Container Image
|
|
# ====================================================
|
|
# Container image for the endpoint, built step by step: a Python 3.11 base,
# then the Python dependencies, then the system packages.
gpu_image = Image(python_version="python3.11")

# Python dependencies: torch/torchaudio for inference and alignment,
# soundfile for WAV encoding, kokoro for speech synthesis.
gpu_image = gpu_image.add_python_packages([
    "torch>=2.4.0",
    "torchaudio>=2.4.0",
    "transformers",
    "scipy",
    "soundfile>=0.12.0",
    "kokoro>=0.1.0",
    "huggingface_hub",
])

# System packages installed through apt (libsndfile1 and ffmpeg).
gpu_image = gpu_image.add_commands([
    "apt-get update && apt-get install -y libsndfile1 ffmpeg",
])
|
|
|
|
# ====================================================
|
|
# Voice IDs
|
|
# ====================================================
|
|
# Voice identifiers accepted by the endpoint; a request whose `voice` is not
# in this list is rejected. Entries are kept in sorted order, one line per
# two-letter prefix.
# NOTE(review): the prefix presumably encodes language + gender following
# Kokoro's voice naming convention -- confirm against the Kokoro voice list.
VOICE_IDS = [
    # af_*
    "af_alloy", "af_aoede", "af_bella", "af_heart", "af_jessica", "af_kore",
    "af_nicole", "af_nova", "af_river", "af_sarah", "af_sky",
    # am_*
    "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam", "am_michael",
    "am_onyx", "am_puck", "am_santa",
    # bf_*
    "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
    # bm_*
    "bm_daniel", "bm_fable", "bm_george", "bm_lewis",
    # ef_* / em_*
    "ef_dora",
    "em_alex", "em_santa",
    # ff_*
    "ff_siwis",
    # hf_* / hm_*
    "hf_alpha", "hf_beta",
    "hm_omega", "hm_psi",
    # if_* / im_*
    "if_sara",
    "im_nicola",
    # jf_* / jm_*
    "jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro",
    "jm_kumo",
    # pf_* / pm_*
    "pf_dora",
    "pm_alex", "pm_santa",
    # zf_*
    "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi",
    # zm_*
    "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang",
]
|
|
|
|
|
|
@dataclass
class TokenSpan:
    """A contiguous run of alignment frames that carry the same label."""

    # Index of the first frame in the run.
    start: int
    # Index one past the last frame in the run (exclusive).
    end: int
    # Label index shared by every frame in the run.
    token: int
|
|
|
|
|
|
# ====================================================
|
|
# on_start — Model Loading
|
|
# ====================================================
|
|
def load_all_models():
    """Load every model the endpoint needs; runs once when the container starts.

    The Kokoro TTS pipeline is mandatory: if it cannot be loaded the error is
    logged and re-raised. The wav2vec2 aligner is optional: if it cannot be
    loaded the failure is logged and the aligner-related entries are ``None``.

    Returns:
        dict with the keys ``pipeline``, ``device``, ``torch_device``,
        ``aligner_model``, ``bundle``, ``labels`` and ``dictionary``.

    Raises:
        Exception: whatever the Kokoro import or construction raised.
    """
    import traceback

    import torch
    import torchaudio

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch_device = torch.device(device)
    banner = '=' * 50

    print(banner)
    print(f"🚀 CONTAINER STARTING — Device: {device}")
    print(banner)

    # --- 1. Kokoro TTS (required) ---
    print("📂 [1/2] Loading Kokoro TTS...")
    try:
        from kokoro import KPipeline
        pipeline = KPipeline(lang_code='a', device=device)
        print("✅ [1/2] Kokoro TTS loaded")
    except Exception as e:
        print(f"❌ [1/2] Kokoro TTS FAILED: {e}")
        traceback.print_exc()
        # No partial result here: nothing works without TTS.
        raise e

    # --- 2. wav2vec2 aligner (optional) ---
    print("📂 [2/2] Loading wav2vec2 Aligner...")
    try:
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        aligner_model = bundle.get_model().to(torch_device)
        labels = bundle.get_labels()
        # Map each label character to its index in the label sequence.
        dictionary = {label: index for index, label in enumerate(labels)}
        print("✅ [2/2] wav2vec2 Aligner loaded")
    except Exception as e:
        print(f"⚠️ [2/2] wav2vec2 FAILED: {e} — will skip alignment")
        traceback.print_exc()
        # Clear everything, including names assigned before the failure.
        bundle = aligner_model = labels = dictionary = None

    print(banner)
    print("🎉 CONTAINER READY")
    print(banner)

    loaded = {
        "pipeline": pipeline,
        "device": device,
        "torch_device": torch_device,
        "aligner_model": aligner_model,
        "bundle": bundle,
        "labels": labels,
        "dictionary": dictionary,
    }
    return loaded
|
|
|
|
|
|
# ====================================================
|
|
# Combined Endpoint
|
|
# ====================================================
|
|
# Compiled once: strips everything except ASCII letters and digits from a word.
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9]')


def _normalize_words(text):
    """Split *text* on whitespace and derive the words used for alignment.

    Returns:
        (original_words, clean_words, valid_indices) where ``clean_words[i]``
        is the upper-cased, alphanumeric-only form of
        ``original_words[valid_indices[i]]``. Words that become empty after
        cleaning are dropped from ``clean_words``.
    """
    original_words = text.split()
    clean_words = []
    valid_indices = []
    for idx, word in enumerate(original_words):
        clean = _NON_ALNUM_RE.sub('', word).upper()
        if clean:
            clean_words.append(clean)
            valid_indices.append(idx)
    return original_words, clean_words, valid_indices


def _path_to_segments(path, blank=0):
    """Collapse a frame-level label path into one TokenSpan per non-blank run."""
    segments = []
    if not path:
        return segments

    current_label = path[0]
    start_frame = 0
    for t, label in enumerate(path):
        if label != current_label:
            if current_label != blank:
                segments.append(
                    TokenSpan(start=start_frame, end=t, token=current_label)
                )
            current_label = label
            start_frame = t
    # Flush the run that reaches the end of the path.
    if current_label != blank:
        segments.append(
            TokenSpan(start=start_frame, end=len(path), token=current_label)
        )
    return segments


def _word_timestamps(segments, clean_words, original_words, valid_indices,
                     seconds_per_frame, separator_token):
    """Group per-character segments into per-word ``{word, start, end}`` dicts.

    Each cleaned word consumes ``len(word)`` consecutive segments; one
    separator segment after a word is skipped when present. Stops early
    (returning the words mapped so far) if the segments run out.
    """
    timestamps = []
    segment_idx = 0
    for i, word_str in enumerate(clean_words):
        word_len = len(word_str)
        if segment_idx + word_len > len(segments):
            break
        first = segments[segment_idx]
        last = segments[segment_idx + word_len - 1]
        timestamps.append({
            "word": original_words[valid_indices[i]],
            "start": round(first.start * seconds_per_frame, 2),
            "end": round(last.end * seconds_per_frame, 2),
        })
        segment_idx += word_len
        if (segment_idx < len(segments)
                and segments[segment_idx].token == separator_token):
            segment_idx += 1
    return timestamps


@endpoint(
    name="tts-combined",
    image=gpu_image,
    on_start=load_all_models,
    gpu="RTX4090",
    cpu=2,
    memory="16Gi",
    keep_warm_seconds=180,
)
def generate_audio_with_timestamps(context, **inputs):
    """Synthesize speech and (optionally) word timestamps in a single request.

    Args:
        context: request context; ``context.on_start_value`` must hold the
            dict returned by ``load_all_models``.
        **inputs: request fields:
            ``text`` (required, at least 2 non-blank characters),
            ``voice`` (default ``"af_heart"``, must be in ``VOICE_IDS``),
            ``speed`` (default 1.0, clamped to [0.5, 2.0]; unparsable values
            fall back to 1.0),
            ``skip_alignment`` (default False).

    Returns:
        On success, a dict with ``success=True``, ``audio_base64`` (WAV),
        ``audio_format``, ``sample_rate``, ``voice``, ``speed``,
        ``text_length``, ``timestamps`` (list of ``{word, start, end}``) and
        ``word_count``. On failure, ``{"error": <message>, "success": False}``.
        An alignment failure is not fatal: the audio is still returned with
        an empty ``timestamps`` list.
    """
    import traceback

    import torch
    import torchaudio
    import soundfile as sf

    # Normalize `text` before anything touches it: a null or non-string value
    # must reach the validation below instead of crashing on len()/slicing.
    raw_text = inputs.get("text", "")
    if raw_text is None:
        text = ""
    elif isinstance(raw_text, str):
        text = raw_text
    else:
        text = str(raw_text)

    print("")
    print("📥 REQUEST RECEIVED")
    print(f" Keys: {list(inputs.keys())}")
    print(f" text length: {len(text)}")

    # ---- Get models from on_start ----
    ctx = context.on_start_value
    if ctx is None:
        print("❌ on_start_value is None!")
        return {"error": "Models not loaded", "success": False}

    pipeline = ctx["pipeline"]
    aligner_model = ctx.get("aligner_model")
    bundle = ctx.get("bundle")
    dictionary = ctx.get("dictionary")
    torch_device = ctx["torch_device"]

    # ---- Parse inputs ----
    voice = inputs.get("voice", "af_heart")
    speed = inputs.get("speed", 1.0)
    skip_alignment = inputs.get("skip_alignment", False)

    if len(text.strip()) < 2:
        return {"error": "Text is required (min 2 chars)", "success": False}

    if voice not in VOICE_IDS:
        return {"error": f"Invalid voice '{voice}'", "success": False}

    try:
        speed = max(0.5, min(2.0, float(speed)))
    except (TypeError, ValueError):
        speed = 1.0

    # =================================================
    # STEP 1: TTS Generation
    # =================================================
    try:
        print(f"🔊 TTS: voice={voice}, speed={speed}, text={len(text)} chars")
        print(f" Preview: {text[:80]}...")

        generator = pipeline(text, voice=voice, speed=speed, split_pattern=r'\n+')
        # Skip empty chunks and coerce to tensors so torch.cat cannot fail on
        # a None or non-tensor chunk.
        all_audio = [
            torch.as_tensor(audio)
            for _gs, _ps, audio in generator
            if audio is not None
        ]

        if not all_audio:
            print("❌ No audio chunks generated")
            return {"error": "No audio generated", "success": False}

        full_audio = torch.cat(all_audio, dim=0)
        sample_rate = 24000

        # Encode the waveform as WAV, then base64 for the JSON response.
        with BytesIO() as audio_buffer:
            sf.write(
                audio_buffer,
                full_audio.detach().cpu().numpy(),
                sample_rate,
                format='WAV',
            )
            audio_bytes = audio_buffer.getvalue()
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        print(f"✅ TTS done: {len(audio_base64)} bytes base64, {len(audio_bytes)} bytes raw")

    except Exception as e:
        # Request boundary: report the failure to the caller instead of raising.
        print(f"❌ TTS FAILED: {e}")
        traceback.print_exc()
        return {"error": f"TTS failed: {str(e)}", "success": False}

    # =================================================
    # STEP 2: Forced Alignment
    # =================================================
    timestamps = []
    aligner_ready = (
        aligner_model is not None
        and bundle is not None
        and dictionary is not None
    )

    if not skip_alignment and aligner_ready:
        try:
            print(f"⏳ Aligning {len(text)} chars...")

            # Reuse the in-memory waveform (no WAV re-decode); add a batch
            # dimension and move it to the aligner's device.
            waveform = full_audio.unsqueeze(0).float().to(torch_device)

            # Resample to the rate the aligner bundle expects.
            if sample_rate != bundle.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=bundle.sample_rate
                ).to(torch_device)
                waveform = resampler(waveform)

            # Text -> tokens
            original_words, clean_words, valid_indices = _normalize_words(text)

            if clean_words:
                separator_token = dictionary.get('|', 0)
                transcript = " ".join(clean_words)
                # Characters missing from the label dictionary fall back to
                # the separator token.
                # NOTE(review): digits survive the cleaning regex; if the
                # label set has no digit entries they align as separators --
                # confirm against bundle.get_labels().
                token_indices = [
                    separator_token if char == ' '
                    else dictionary.get(char, separator_token)
                    for char in transcript
                ]

                targets = torch.tensor(
                    token_indices, dtype=torch.int32, device=torch_device
                ).unsqueeze(0)

                with torch.inference_mode():
                    emissions, _ = aligner_model(waveform)
                    emissions = torch.log_softmax(emissions, dim=-1)

                input_lengths = torch.tensor([emissions.size(1)], device=torch_device)
                target_lengths = torch.tensor([targets.size(1)], device=torch_device)

                path, _ = torchaudio.functional.forced_align(
                    emissions, targets, input_lengths, target_lengths, blank=0
                )
                path = path[0].tolist()

                if path:
                    segments = _path_to_segments(path, blank=0)
                    # Seconds of (resampled) audio covered by one path frame.
                    seconds_per_frame = (
                        waveform.size(1) / len(path) / bundle.sample_rate
                    )
                    timestamps = _word_timestamps(
                        segments,
                        clean_words,
                        original_words,
                        valid_indices,
                        seconds_per_frame,
                        separator_token,
                    )

            print(f"✅ Aligned {len(timestamps)} words")

        except Exception as e:
            # Best effort by design: the audio is valid even without timestamps.
            print(f"⚠️ Alignment failed (audio still valid): {e}")
            traceback.print_exc()
            timestamps = []
    else:
        if skip_alignment:
            print("⏭️ Alignment skipped (skip_alignment=True)")
        else:
            print("⚠️ Alignment skipped (model not loaded)")

    # =================================================
    # Return Result
    # =================================================
    result = {
        "success": True,
        "audio_base64": audio_base64,
        "audio_format": "wav",
        "sample_rate": sample_rate,
        "voice": voice,
        "speed": speed,
        "text_length": len(text),
        "timestamps": timestamps,
        "word_count": len(timestamps),
    }

    print(f"📤 RESPONSE: success=True, audio={len(audio_base64)} bytes, words={len(timestamps)}")
    print("")

    return result
|