#!/usr/bin/env python3
"""
Chatterbox Voice Clone Generator

Clone any voice from a single reference clip and generate speech.

Usage:
    python gen_voice.py --voice voices/ref.wav "Your text here"
    python gen_voice.py --voice voices/ref.wav --file script.txt
    python gen_voice.py --voice voices/ref.wav --file script.txt --chunk
    python gen_voice.py --list
"""

import argparse
import glob
import os
import re
import sys

import torch
import torchaudio

# ============================================================
# CONFIG
# ============================================================
# Directory containing this script; voices/ and output/ are resolved against it
# rather than against the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Folder scanned for reference voice clips.
VOICES_DIR = os.path.join(_SCRIPT_DIR, "voices")
# Folder that receives generated audio when no --output path is given.
OUTPUT_DIR = os.path.join(_SCRIPT_DIR, "output")
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# ============================================================
# MODEL
# ============================================================
# Process-wide cache for the loaded TTS model; populated on first get_model() call.
_model = None


def get_model():
    """Return the shared Chatterbox Turbo model, loading it on first use.

    The model is created once via ``ChatterboxTurboTTS.from_pretrained`` on
    ``DEVICE`` (weights are downloaded on first run) and cached in the
    module-level ``_model``; later calls return the cached instance.
    """
    global _model

    # Fast path: already loaded.
    if _model is not None:
        return _model

    print(f"Loading Chatterbox Turbo on {DEVICE}...")
    # Imported lazily so modes that never generate audio do not need the package.
    from chatterbox.tts_turbo import ChatterboxTurboTTS

    _model = ChatterboxTurboTTS.from_pretrained(device=DEVICE)
    print("Model ready.\n")
    return _model


# ============================================================
# GENERATION
# ============================================================

def generate_single(text, voice_path, output_path):
    """Synthesize ``text`` in a single model call and save the result.

    Args:
        text: Text to speak.
        voice_path: Path to the reference clip passed to the model as the
            audio prompt.
        output_path: Where the audio is written; its parent directory is
            created if it does not exist.

    Returns:
        ``output_path``.
    """
    tts = get_model()

    print(f"Voice:  {os.path.basename(voice_path)}")
    # Only the first 100 characters are echoed; longer text gets an ellipsis.
    suffix = '...' if len(text) > 100 else ''
    print(f"Text:   \"{text[:100]}{suffix}\"")
    print("Generating...")

    audio = tts.generate(text, audio_prompt_path=voice_path)

    # A bare filename has no directory component; fall back to the cwd.
    target_dir = os.path.dirname(output_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    torchaudio.save(output_path, audio, tts.sr)

    # Assumes the samples lie along dimension 1 of the returned tensor.
    seconds = audio.shape[1] / tts.sr
    print(f"Saved:  {output_path} ({seconds:.1f}s)")
    return output_path


def generate_chunked(text, voice_path, output_path, gap_seconds=0.4):
    """Generate speech one sentence at a time and stitch the clips together.

    The text is split after sentence-ending punctuation (``.``, ``!``, ``?``)
    followed by whitespace. Each sentence is generated separately and the
    clips are concatenated along the sample dimension with a silence gap
    *between* consecutive sentences (none after the last one).

    Args:
        text: Text to speak.
        voice_path: Path to the reference clip passed to the model as the
            audio prompt.
        output_path: Where the combined audio is written; its parent
            directory is created if it does not exist.
        gap_seconds: Length of the silence inserted between sentences, in
            seconds. Values <= 0 disable the gap.

    Returns:
        ``output_path``, or ``None`` when ``text`` contains nothing to speak
        (in which case the model is never loaded and no file is written).
    """
    # Split first: empty input should not pay for loading the model.
    sentences = [
        part.strip()
        for part in re.split(r'(?<=[.!?])\s+', text.strip())
        if part.strip()
    ]

    if not sentences:
        print("No text to generate.")
        return None

    model = get_model()

    print(f"Voice:   {os.path.basename(voice_path)}")
    print(f"Chunks:  {len(sentences)}")
    print()

    total = len(sentences)
    gap_frames = int(model.sr * gap_seconds)  # loop-invariant
    pieces = []
    for index, sentence in enumerate(sentences, start=1):
        preview = f"{sentence[:70]}{'...' if len(sentence) > 70 else ''}"
        print(f"  [{index}/{total}] \"{preview}\"")
        wav = model.generate(sentence, audio_prompt_path=voice_path)
        if pieces and gap_frames > 0:
            # Silence goes only between clips, and is built from the clip
            # itself so row count, dtype and device always match for cat().
            pieces.append(wav.new_zeros((wav.shape[0], gap_frames)))
        pieces.append(wav)

    combined = torch.cat(pieces, dim=1)

    # A bare filename has no directory component; fall back to the cwd.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    torchaudio.save(output_path, combined, model.sr)

    duration = combined.shape[1] / model.sr
    print(f"\nSaved:  {output_path} ({duration:.1f}s, {total} chunks)")
    return output_path


def list_voices():
    """Print the WAV/MP3 reference clips found in the voices/ folder.

    Clips are listed in sorted path order. Each line shows the file name and,
    when ``torchaudio.info`` can read the clip, its duration in seconds; if
    that fails for a clip, only the name is printed.
    """
    found = []
    for pattern in ("*.wav", "*.mp3"):
        found.extend(glob.glob(os.path.join(VOICES_DIR, pattern)))

    if not found:
        print("No voice clips found in voices/ folder.")
        print("Drop a 15-30 second WAV or MP3 reference clip in there.")
        return

    print("Available voices:\n")
    found.sort()
    for path in found:
        filename = os.path.basename(path)
        try:
            meta = torchaudio.info(path)
            seconds = meta.num_frames / meta.sample_rate
            print(f"  {filename:<30} ({seconds:.1f}s)")
        except Exception:
            # Best effort: an unreadable clip is still listed, without duration.
            print(f"  {filename}")


# ============================================================
# MAIN
# ============================================================

def main():
    """Command-line entry point.

    Parses the arguments, then either lists the available voice clips
    (``--list``) or resolves a reference voice, the text to speak and an
    output path, and generates the audio. Text longer than 500 characters, or
    any text when ``--chunk`` is given, is generated sentence by sentence.

    Exits with status 1 (after printing a message) when the voice clip cannot
    be resolved, the text file cannot be read, or there is no text to speak.
    """
    parser = argparse.ArgumentParser(
        description="Clone any voice and generate speech",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python gen_voice.py --voice voices/ref.wav "Hello world"
  python gen_voice.py --voice voices/ref.wav --file script.txt --chunk
  python gen_voice.py --list
        """
    )
    parser.add_argument("text", nargs="?", help="Text to speak")
    parser.add_argument("--voice", "-v", help="Path to reference voice clip (WAV/MP3)")
    parser.add_argument("--file", "-f", help="Read text from file")
    parser.add_argument("--output", "-o", default=None, help="Output WAV path")
    parser.add_argument("--chunk", "-c", action="store_true",
                        help="Process one sentence at a time (saves VRAM)")
    parser.add_argument("--list", "-l", action="store_true",
                        help="List available voice clips")

    args = parser.parse_args()

    # List mode
    if args.list:
        list_voices()
        return

    # Validate voice — auto-detect if only one clip exists
    if not args.voice:
        clips = glob.glob(os.path.join(VOICES_DIR, "*.wav")) + \
                glob.glob(os.path.join(VOICES_DIR, "*.mp3"))
        if len(clips) == 1:
            args.voice = clips[0]
            print(f"Auto-detected voice: {os.path.basename(args.voice)}")
        elif len(clips) > 1:
            print("Multiple voices found. Specify one with --voice:")
            for c in sorted(clips):
                print(f"  --voice {c}")
            sys.exit(1)
        else:
            print("ERROR: No voice clips found in voices/ folder.")
            print("Drop a reference WAV in voices/ or use --voice path/to/clip.wav")
            sys.exit(1)

    # isfile rather than exists: a directory is not a usable reference clip.
    if not os.path.isfile(args.voice):
        print(f"ERROR: Voice clip not found: {args.voice}")
        sys.exit(1)

    # Get text (--file takes precedence over the positional argument)
    if args.file:
        try:
            # Explicit encoding so the script reads the same on every platform;
            # "utf-8-sig" also drops a leading BOM if the editor wrote one.
            with open(args.file, "r", encoding="utf-8-sig") as f:
                text = f.read().strip()
        except (OSError, UnicodeDecodeError) as exc:
            # Report like the other input errors instead of a raw traceback.
            print(f"ERROR: Could not read text file {args.file}: {exc}")
            sys.exit(1)
    elif args.text:
        text = args.text
    else:
        print("Provide text as argument or use --file")
        sys.exit(1)

    # An empty file or whitespace-only argument would otherwise reach the model.
    if not text.strip():
        print("ERROR: No text to generate (input is empty).")
        sys.exit(1)

    # Output path
    if args.output:
        output_path = args.output
    else:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        voice_name = os.path.splitext(os.path.basename(args.voice))[0]
        output_path = os.path.join(OUTPUT_DIR, f"{voice_name}_output.wav")

    # Generate
    if args.chunk or len(text) > 500:
        generate_chunked(text, args.voice, output_path)
    else:
        generate_single(text, args.voice, output_path)

    print("\nDone!")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
