# beam_app/app.py - Combined TTS + Timestamp Endpoint (Fixed)
"""Beam endpoint: Kokoro text-to-speech plus optional per-word timestamps.

Audio is synthesized with Kokoro and returned as base64-encoded WAV. Unless
disabled, the same audio is force-aligned against the input text with a
torchaudio wav2vec2 CTC model to produce ``{"word", "start", "end"}`` entries.

torch / torchaudio / soundfile / kokoro are imported inside the functions
that run in the container, presumably so this module can be imported where
those packages are not installed (e.g. at deploy time) — confirm.
"""

import base64
import itertools
import math
import os
import re
import traceback
from dataclasses import dataclass
from io import BytesIO

from beam import Image, endpoint

# ====================================================
# Container Image
# ====================================================
gpu_image = (
    Image(python_version="python3.11")
    .add_python_packages([
        "torch>=2.4.0",
        "torchaudio>=2.4.0",
        "transformers",
        "scipy",
        "soundfile>=0.12.0",
        "kokoro>=0.1.0",
        "huggingface_hub",
    ])
    .add_commands([
        "apt-get update && apt-get install -y libsndfile1 ffmpeg",
    ])
)

# ====================================================
# Voice IDs
# ====================================================
# Voices accepted by the endpoint. NOTE(review): the pipeline is created with
# lang_code='a' only, so non-English voices here are paired with that single
# pipeline's text processing — confirm this is intended.
VOICE_IDS = [
    "af_alloy", "af_aoede", "af_bella", "af_heart", "af_jessica", "af_kore",
    "af_nicole", "af_nova", "af_river", "af_sarah", "af_sky",
    "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam", "am_michael",
    "am_onyx", "am_puck", "am_santa",
    "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
    "bm_daniel", "bm_fable", "bm_george", "bm_lewis",
    "ef_dora", "em_alex", "em_santa",
    "ff_siwis",
    "hf_alpha", "hf_beta", "hm_omega", "hm_psi",
    "if_sara", "im_nicola",
    "jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo",
    "pf_dora", "pm_alex", "pm_santa",
    "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi",
    "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang",
]

# Sample rate written into the WAV header for Kokoro output (hard-coded, as
# in the original implementation).
_TTS_SAMPLE_RATE = 24000

# Speed values outside this range are clamped.
_MIN_SPEED = 0.5
_MAX_SPEED = 2.0
_DEFAULT_SPEED = 1.0

# CTC blank label index passed to forced_align and ignored when merging spans.
_CTC_BLANK = 0

# String spellings accepted as "true" for boolean request flags.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})


@dataclass
class TokenSpan:
    """A run of consecutive aligner frames that emit the same non-blank token."""

    start: int  # first frame index (inclusive)
    end: int  # one past the last frame index (exclusive)
    token: int  # label index in the aligner's dictionary


# ====================================================
# on_start — Model Loading
# ====================================================
def load_all_models():
    """Load all models once, when the container starts (Beam ``on_start``).

    Kokoro TTS is required: a loading failure is logged and re-raised. The
    wav2vec2 aligner is optional: on failure its entries are set to ``None``
    and requests return audio without timestamps.

    Returns:
        dict with keys ``pipeline``, ``device``, ``torch_device``,
        ``aligner_model``, ``bundle``, ``labels`` and ``dictionary``. The
        endpoint reads it back as ``context.on_start_value``.
    """
    import torch
    import torchaudio

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    banner = "=" * 50
    print(banner)
    print(f"🚀 CONTAINER STARTING — Device: {device}")
    print(banner)

    # --- 1. Kokoro TTS (required) ---
    print("📂 [1/2] Loading Kokoro TTS...")
    try:
        from kokoro import KPipeline

        # lang_code='a' is presumably Kokoro's American-English pipeline — confirm.
        pipeline = KPipeline(lang_code="a", device=device)
    except Exception as e:
        print(f"❌ [1/2] Kokoro TTS FAILED: {e}")
        traceback.print_exc()
        # Nothing works without TTS, so fail container start-up.
        raise
    print("✅ [1/2] Kokoro TTS loaded")

    # --- 2. wav2vec2 aligner (optional) ---
    print("📂 [2/2] Loading wav2vec2 Aligner...")
    try:
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        aligner_model = bundle.get_model().to(torch_device)
        labels = bundle.get_labels()
        # Character -> label index, used to build CTC targets.
        dictionary = {c: i for i, c in enumerate(labels)}
        print("✅ [2/2] wav2vec2 Aligner loaded")
    except Exception as e:
        print(f"⚠️ [2/2] wav2vec2 FAILED: {e} — will skip alignment")
        traceback.print_exc()
        bundle = aligner_model = labels = dictionary = None

    print(banner)
    print("🎉 CONTAINER READY")
    print(banner)

    return {
        "pipeline": pipeline,
        "device": device,
        "torch_device": torch_device,
        "aligner_model": aligner_model,
        "bundle": bundle,
        "labels": labels,
        "dictionary": dictionary,
    }


# ====================================================
# Request helpers
# ====================================================
def _coerce_text(value):
    """Return request text as ``str``; a missing/``None`` value becomes ``""``."""
    return "" if value is None else str(value)


def _parse_speed(value):
    """Parse a speed value, clamped to [_MIN_SPEED, _MAX_SPEED].

    Unparseable values and NaN fall back to ``_DEFAULT_SPEED``.
    """
    try:
        speed = float(value)
    except (TypeError, ValueError):
        return _DEFAULT_SPEED
    if math.isnan(speed):
        return _DEFAULT_SPEED
    return max(_MIN_SPEED, min(_MAX_SPEED, speed))


def _parse_flag(value):
    """Interpret a request flag; strings like "false"/"0" are treated as False."""
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return bool(value)


def _synthesize(pipeline, text, voice, speed):
    """Run the Kokoro pipeline and concatenate its audio chunks.

    Returns:
        A 1-D tensor holding the whole utterance, or ``None`` when the pipeline
        produced no audio. Chunks whose audio is ``None`` are skipped.
    """
    import torch

    chunks = []
    # The pipeline yields (graphemes, phonemes, audio) per text segment;
    # segments are split on runs of newlines.
    for _graphemes, _phonemes, audio in pipeline(
        text, voice=voice, speed=speed, split_pattern=r"\n+"
    ):
        if audio is not None:
            chunks.append(torch.as_tensor(audio))
    if not chunks:
        return None
    return torch.cat(chunks, dim=0)


def _encode_wav_base64(audio, sample_rate):
    """Encode a waveform tensor as WAV; return (base64 string, raw byte count)."""
    import soundfile as sf

    buffer = BytesIO()
    sf.write(buffer, audio.cpu().numpy(), sample_rate, format="WAV")
    raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8"), len(raw)


def _merge_token_spans(path, blank=_CTC_BLANK):
    """Collapse a frame-level CTC label path into non-blank ``TokenSpan`` runs."""
    spans = []
    frame = 0
    for label, run in itertools.groupby(path):
        length = sum(1 for _ in run)
        if label != blank:
            spans.append(TokenSpan(start=frame, end=frame + length, token=label))
        frame += length
    return spans


def _spans_to_word_timestamps(
    spans, clean_words, original_words, valid_indices, separator, seconds_per_frame
):
    """Map token spans back to words and return timestamp dicts.

    Each cleaned word consumes ``len(word)`` spans, followed by one optional
    separator span. Times are in seconds, rounded to two decimals. Stops
    early if the spans run out.
    """
    timestamps = []
    span_idx = 0
    for word_pos, clean in enumerate(clean_words):
        last = span_idx + len(clean) - 1
        if last >= len(spans):
            break
        timestamps.append({
            "word": original_words[valid_indices[word_pos]],
            "start": round(spans[span_idx].start * seconds_per_frame, 2),
            "end": round(spans[last].end * seconds_per_frame, 2),
        })
        span_idx = last + 1
        if span_idx < len(spans) and spans[span_idx].token == separator:
            span_idx += 1
    return timestamps


def _align_words(text, full_audio, sample_rate, aligner_model, bundle, dictionary, torch_device):
    """Force-align ``text`` against ``full_audio``; return per-word timestamps.

    Words are split on whitespace and stripped to ASCII letters/digits
    (upper-cased). Words that become empty are omitted from the result.
    Characters missing from ``dictionary`` are mapped to the word separator
    '|' (falling back to the blank index if '|' is absent).

    NOTE(review): alignment uses the raw input words, while the TTS may
    normalize text (numbers, abbreviations) differently — verify timestamps
    for such inputs.
    """
    import torch
    import torchaudio

    separator = dictionary.get("|", _CTC_BLANK)

    original_words = text.split()
    clean_words = []
    valid_indices = []
    for idx, word in enumerate(original_words):
        clean = re.sub(r"[^a-zA-Z0-9]", "", word).upper()
        if clean:
            clean_words.append(clean)
            valid_indices.append(idx)
    if not clean_words:
        return []

    transcript = " ".join(clean_words)
    token_indices = [
        separator if ch == " " else dictionary.get(ch, separator) for ch in transcript
    ]

    # Reuse the synthesized tensor directly; no need to re-decode the WAV.
    waveform = full_audio.unsqueeze(0).float().to(torch_device)
    if sample_rate != bundle.sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=bundle.sample_rate
        ).to(torch_device)
        waveform = resampler(waveform)

    targets = torch.tensor(token_indices, dtype=torch.int32, device=torch_device).unsqueeze(0)

    with torch.inference_mode():
        emissions, _ = aligner_model(waveform)
        emissions = torch.log_softmax(emissions, dim=-1)

    input_lengths = torch.tensor([emissions.size(1)], device=torch_device)
    target_lengths = torch.tensor([targets.size(1)], device=torch_device)
    path, _ = torchaudio.functional.forced_align(
        emissions, targets, input_lengths, target_lengths, blank=_CTC_BLANK
    )
    path = path[0].tolist()
    if not path:
        return []

    spans = _merge_token_spans(path, blank=_CTC_BLANK)
    # Seconds of (resampled) audio covered by one emission frame.
    seconds_per_frame = waveform.size(1) / len(path) / bundle.sample_rate
    return _spans_to_word_timestamps(
        spans, clean_words, original_words, valid_indices, separator, seconds_per_frame
    )


# ====================================================
# Combined Endpoint
# ====================================================
@endpoint(
    name="tts-combined",
    image=gpu_image,
    on_start=load_all_models,
    gpu="RTX4090",
    cpu=2,
    memory="16Gi",
    keep_warm_seconds=180,
)
def generate_audio_with_timestamps(context, **inputs):
    """Synthesize speech and, optionally, word timestamps on the same GPU.

    Inputs:
        text: text to speak; must be at least 2 characters after stripping.
        voice: one of ``VOICE_IDS`` (default ``"af_heart"``).
        speed: float clamped to [0.5, 2.0]; invalid or NaN values use 1.0.
        skip_alignment: truthy value (or "true"/"1"/"yes"/"on") to skip
            timestamp alignment.

    Returns:
        On success, a dict with ``success=True`` and keys ``audio_base64``,
        ``audio_format``, ``sample_rate``, ``voice``, ``speed``,
        ``text_length``, ``timestamps`` and ``word_count``. On failure,
        ``{"error": <message>, "success": False}``. An alignment failure is
        not a failure: the audio is still returned, with empty timestamps.
    """
    text = _coerce_text(inputs.get("text"))

    print()
    print("📥 REQUEST RECEIVED")
    print(f" Keys: {list(inputs.keys())}")
    print(f" text length: {len(text)}")

    # ---- Models loaded by on_start ----
    ctx = context.on_start_value
    if ctx is None:
        print("❌ on_start_value is None!")
        return {"error": "Models not loaded", "success": False}

    pipeline = ctx["pipeline"]
    aligner_model = ctx.get("aligner_model")
    bundle = ctx.get("bundle")
    dictionary = ctx.get("dictionary")
    torch_device = ctx["torch_device"]

    # ---- Parse and validate inputs ----
    voice = inputs.get("voice", "af_heart")
    speed = _parse_speed(inputs.get("speed", _DEFAULT_SPEED))
    skip_alignment = _parse_flag(inputs.get("skip_alignment", False))

    if len(text.strip()) < 2:
        return {"error": "Text is required (min 2 chars)", "success": False}
    if voice not in VOICE_IDS:
        return {"error": f"Invalid voice '{voice}'", "success": False}

    # =================================================
    # STEP 1: TTS Generation
    # =================================================
    sample_rate = _TTS_SAMPLE_RATE
    try:
        print(f"🔊 TTS: voice={voice}, speed={speed}, text={len(text)} chars")
        print(f" Preview: {text[:80]}...")

        full_audio = _synthesize(pipeline, text, voice, speed)
        if full_audio is None:
            print("❌ No audio chunks generated")
            return {"error": "No audio generated", "success": False}

        audio_base64, raw_size = _encode_wav_base64(full_audio, sample_rate)
        print(f"✅ TTS done: {len(audio_base64)} bytes base64, {raw_size} bytes raw")
    except Exception as e:
        print(f"❌ TTS FAILED: {e}")
        traceback.print_exc()
        return {"error": f"TTS failed: {str(e)}", "success": False}

    # =================================================
    # STEP 2: Forced Alignment (best effort)
    # =================================================
    timestamps = []
    if skip_alignment:
        print("⏭️ Alignment skipped (skip_alignment=True)")
    elif aligner_model is None or bundle is None or dictionary is None:
        print("⚠️ Alignment skipped (model not loaded)")
    else:
        try:
            print(f"⏳ Aligning {len(text)} chars...")
            timestamps = _align_words(
                text, full_audio, sample_rate, aligner_model, bundle, dictionary, torch_device
            )
            print(f"✅ Aligned {len(timestamps)} words")
        except Exception as e:
            # Alignment is optional: keep the audio, drop the timestamps.
            print(f"⚠️ Alignment failed (audio still valid): {e}")
            traceback.print_exc()
            timestamps = []

    # =================================================
    # Return Result
    # =================================================
    result = {
        "success": True,
        "audio_base64": audio_base64,
        "audio_format": "wav",
        "sample_rate": sample_rate,
        "voice": voice,
        "speed": speed,
        "text_length": len(text),
        "timestamps": timestamps,
        "word_count": len(timestamps),
    }

    print(f"📤 RESPONSE: success=True, audio={len(audio_base64)} bytes, words={len(timestamps)}")
    print()
    return result