/*
 * ZipVoiceTTSManager.kt
 * (Paste-site metadata: author unknown, plain_text, 3 months ago, 15 kB.)
 */
package com.angel.zipvoice

import android.content.Context
import android.util.Log
import ai.onnxruntime.*
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.IValue
import java.io.*
import java.nio.*
import java.util.*
import kotlin.math.*
import be.tarsos.dsp.AudioDispatcher
import be.tarsos.dsp.AudioEvent
import be.tarsos.dsp.AudioProcessor
import be.tarsos.dsp.io.jvm.AudioDispatcherFactory
import be.tarsos.dsp.util.fft.FFT
import be.tarsos.dsp.util.fft.HannWindow

class ZipVoiceTTSManager private constructor(private val context: Context) {
    /** Log tag shared by every message this class emits. */
    private val tag = "ZipVoiceTTS"

    /**
     * Whether model loading has completed successfully. Set to true at the end of
     * [initModel], and back to false on initialisation failure and in [release].
     *
     * The flag is written on the background initialisation thread and read by
     * whichever thread calls [synthesize], so it is volatile: a reader that
     * observes `true` also observes the sessions and vocoder assigned before it.
     */
    @Volatile
    var isInitialized = false
        private set

    // ONNX Runtime environment and the two inference sessions.
    private var ortEnv: OrtEnvironment? = null
    private var textEncoderSession: OrtSession? = null
    private var fmDecoderSession: OrtSession? = null

    // PyTorch vocoder.
    private var vocoderModule: Module? = null

    // Model parameters.
    private val featDim = 100        // feature width used to reshape flat buffers into (frames, featDim)
    private val samplingRate = 24000 // sample rate handed to the mel filter bank
    private val n_fft = 1024         // analysis frame length in samples
    private val hop_length = 256     // step between consecutive analysis frames
    private val num_mels = 100       // number of mel bands produced per frame

    // Vocabulary: token string -> integer id, filled by loadTokens().
    private var tokenToId: MutableMap<String, Int> = HashMap()

    /**
     * Callback for the outcome of the background model loading.
     *
     * Both methods are invoked from the initialisation thread, not the caller's thread.
     */
    interface OnInitListener {
        /** Called once every model has been loaded. */
        fun onSuccess()

        /** Called when loading fails; [msg] is the failure's message. */
        fun onError(msg: String)
    }

    // Volatile because it is assigned by the caller's thread and read by the init thread.
    @Volatile
    private var initListener: OnInitListener? = null

    /**
     * Registers the listener notified when initialisation finishes.
     *
     * Loading starts as soon as the manager is constructed, so a listener registered
     * after loading has already finished receives no callback; check [isInitialized]
     * as well.
     */
    fun setOnInitListener(listener: OnInitListener) { this.initListener = listener }

    companion object {
        @Volatile
        private var instance: ZipVoiceTTSManager? = null

        /**
         * Returns the process-wide manager, creating it on first use.
         *
         * Uses double-checked locking on the companion object. Only the application
         * context is retained, so passing an Activity does not leak it.
         */
        fun getInstance(context: Context): ZipVoiceTTSManager {
            return instance ?: synchronized(this) {
                instance ?: ZipVoiceTTSManager(context.applicationContext).also { instance = it }
            }
        }

        /**
         * Releases the current instance (if any) and forgets it, so the next
         * [getInstance] call builds a fresh manager.
         *
         * Takes the same lock as [getInstance]; without it a concurrent creation
         * could be overwritten with null and never released.
         */
        fun releaseInstance() {
            synchronized(this) {
                instance?.release()
                instance = null
            }
        }
    }

    // Model loading is kicked off on its own thread as soon as the manager is constructed.
    init {
        val loader = Thread { initModel() }
        loader.start()
    }

    /**
     * Copies an APK asset into the app's files directory and returns the copy.
     *
     * The destination name is "zipvoice_" followed by [assetPath] with '/' replaced
     * by '_'. An existing destination is reused without being compared to the asset.
     *
     * The data is written to a temporary sibling file and renamed into place only
     * after the copy completes, so an interrupted copy can never be mistaken for a
     * complete model on a later run.
     *
     * @param assetPath path of the asset inside the APK's assets folder.
     * @return the file in the app's files directory.
     * @throws IOException if the asset cannot be read or the copy cannot be written.
     */
    private fun copyAssetToFile(assetPath: String): File {
        val destFile = File(context.filesDir, "zipvoice_" + assetPath.replace('/', '_'))
        if (destFile.exists()) return destFile

        destFile.parentFile?.mkdirs()
        val tempFile = File(destFile.parentFile, destFile.name + ".tmp")
        try {
            context.assets.open(assetPath).use { input ->
                FileOutputStream(tempFile).use { output -> input.copyTo(output) }
            }
            if (!tempFile.renameTo(destFile)) {
                throw IOException("Could not move ${tempFile.absolutePath} to ${destFile.absolutePath}")
            }
            Log.d(tag, "✅ 模型文件已复制: $assetPath")
        } catch (e: IOException) {
            // Remove the partial copy and surface the real cause instead of handing
            // the caller a file that is missing or truncated.
            tempFile.delete()
            Log.e(tag, "复制文件失败: $assetPath", e)
            throw e
        }
        return destFile
    }

    /**
     * Rebuilds the token table from a file of "token<TAB>id" lines.
     *
     * Only trailing whitespace is stripped from each line, so a token that itself
     * starts with (or consists of) a space is kept. Lines that do not have exactly
     * two tab-separated fields, have an empty token, or have a non-integer id are
     * skipped rather than aborting the whole load.
     *
     * Failures to read the file are logged and leave the table with whatever was
     * parsed up to that point.
     *
     * @param tokensFile the token table, read as UTF-8.
     */
    private fun loadTokens(tokensFile: File) {
        tokenToId.clear()
        try {
            tokensFile.useLines(Charsets.UTF_8) { lines ->
                for (rawLine in lines) {
                    val parts = rawLine.trimEnd().split("\t")
                    if (parts.size != 2 || parts[0].isEmpty()) continue
                    val id = parts[1].trim().toIntOrNull()
                    if (id == null) {
                        Log.w(tag, "Skipping token line with non-integer id: $rawLine")
                        continue
                    }
                    tokenToId[parts[0]] = id
                }
            }
            Log.i(tag, "✅ 加载了 ${tokenToId.size} 个 token")
        } catch (e: Exception) {
            Log.e(tag, "加载 tokens.txt 失败", e)
        }
    }

    /**
     * Loads everything the synthesiser needs: model files, the token table, the
     * ONNX Runtime native libraries, both ONNX sessions and the PyTorch vocoder.
     *
     * On success [isInitialized] becomes true and the listener's onSuccess is
     * called; on any failure [isInitialized] is false and onError receives the
     * failure message. Runs on the thread started by the constructor.
     */
    private fun initModel() {
        try {
            Log.i(tag, "========== ZipVoice TTS 初始化开始 ==========")

            // 1. Copy the ONNX models, token table and vocoder out of the assets.
            val textEncoderFile = copyAssetToFile("models/tts/zipvoice/text_encoder.onnx")
            val fmDecoderFile = copyAssetToFile("models/tts/zipvoice/fm_decoder.onnx")
            val tokensFile = copyAssetToFile("models/tts/zipvoice/tokens.txt")
            val vocoderFile = copyAssetToFile("models/tts/zipvoice/vocoder.pt")
            loadTokens(tokensFile)

            // 2. Copy and load the ONNX Runtime core library and its JNI bridge.
            //    The files keep their original names under filesDir.
            val coreFile = File(context.filesDir, "libonnxruntime.so")
            val jniFile = File(context.filesDir, "libonnxruntime4j_jni.so")

            if (copyNativeLibIfMissing("libs/arm64-v8a/libonnxruntime.so", coreFile)) {
                Log.i(tag, "✅ 标准版核心库已复制到: ${coreFile.absolutePath}")
            }
            if (copyNativeLibIfMissing("libs/arm64-v8a/libonnxruntime4j_jni.so", jniFile)) {
                Log.i(tag, "✅ JNI 桥接库已复制到: ${jniFile.absolutePath}")
            }

            // The core library is loaded first, then the JNI bridge.
            System.load(coreFile.absolutePath)
            Log.i(tag, "✅ 标准版 ONNX Runtime 核心库加载成功")
            System.load(jniFile.absolutePath)
            Log.i(tag, "✅ ONNX Runtime JNI 桥接库加载成功")

            // 3. Create the ONNX Runtime environment and both sessions.
            val env = OrtEnvironment.getEnvironment()
            ortEnv = env
            val sessionOptions = OrtSession.SessionOptions()
            try {
                sessionOptions.setIntraOpNumThreads(2)
                textEncoderSession = env.createSession(textEncoderFile.absolutePath, sessionOptions)
                fmDecoderSession = env.createSession(fmDecoderFile.absolutePath, sessionOptions)
            } finally {
                // The options object holds native memory of its own; it is only
                // needed while the sessions are being created.
                sessionOptions.close()
            }

            // 4. Load the PyTorch vocoder.
            vocoderModule = Module.load(vocoderFile.absolutePath)

            isInitialized = true
            Log.i(tag, "🎉 ZipVoice TTS 初始化成功")
            initListener?.onSuccess()
        } catch (e: Exception) {
            reportInitFailure(e)
        } catch (e: LinkageError) {
            // System.load reports a missing or incompatible library with
            // UnsatisfiedLinkError, which is an Error and not an Exception.
            // Without this branch the thread would die and the listener would
            // never hear about the failure.
            reportInitFailure(e)
        }
    }

    /**
     * Copies [assetPath] to [dest] unless [dest] already exists.
     *
     * @return true when a copy was made, false when the existing file was kept.
     * @throws IOException if the copy fails; the partial file is removed first so
     *         the next attempt starts clean.
     */
    private fun copyNativeLibIfMissing(assetPath: String, dest: File): Boolean {
        if (dest.exists()) return false
        try {
            context.assets.open(assetPath).use { input ->
                FileOutputStream(dest).use { output -> input.copyTo(output) }
            }
        } catch (e: IOException) {
            dest.delete()
            throw e
        }
        return true
    }

    /** Logs an initialisation failure, clears the ready flag and notifies the listener. */
    private fun reportInitFailure(cause: Throwable) {
        Log.e(tag, "❌ ZipVoice TTS 初始化失败", cause)
        isInitialized = false
        initListener?.onError(cause.message ?: "未知错误")
    }

    // ---------- Tokenizer ----------
    /**
     * Maps whitespace-separated words to token ids.
     *
     * Blank input yields an empty array. A word missing from the table falls back
     * to the id of "<unk>", or to 0 when "<unk>" is missing too. Splitting on
     * whitespace is a simplification; real use must tokenise the way training did.
     */
    private fun textToTokenIds(text: String): IntArray {
        if (text.isBlank()) return intArrayOf()
        val words = text.trim().split("\\s+".toRegex())
        val ids = IntArray(words.size)
        for ((position, word) in words.withIndex()) {
            ids[position] = tokenToId[word] ?: tokenToId["<unk>"] ?: 0
        }
        return ids
    }

    // ---------- Mel spectrogram extraction (Hann window applied manually) ----------
    /**
     * Decodes an audio file and returns its log-mel spectrogram.
     *
     * Frames are [n_fft] samples long and start every [hop_length] samples. Each
     * frame is Hann-windowed, transformed, turned into a power spectrum, projected
     * onto [num_mels] triangular mel filters and passed through ln(x + 1e-7).
     *
     * NOTE(review): the file's sample rate is not checked against [samplingRate];
     * presumably callers supply 24 kHz audio — confirm.
     * NOTE(review): this uses the `io.jvm` dispatcher factory, which presumably
     * relies on javax.sound.sampled — confirm it is available on the target device.
     * NOTE(review): the features are log *power*; confirm the model was not trained
     * on log magnitude or on scaled features.
     *
     * @param filePath path of the audio file to analyse.
     * @return a flat array of numFrames * num_mels values, frame after frame; empty
     *         when the audio is shorter than one frame.
     */
    fun extractMelFromWav(filePath: String): FloatArray {
        val file = File(filePath)
        // Ask for non-overlapping buffers. With a non-zero overlap every buffer
        // starts with the tail of the previous one, and appending whole buffers
        // (as done below) would duplicate those samples in the decoded signal.
        val dispatcher = AudioDispatcherFactory.fromFile(file, n_fft, 0)
        val chunks = ArrayList<FloatArray>()
        dispatcher.addAudioProcessor(object : AudioProcessor {
            override fun process(audioEvent: AudioEvent): Boolean {
                // Copy defensively in case the dispatcher reuses its buffer between events.
                chunks.add(audioEvent.floatBuffer.copyOf(audioEvent.bufferSize))
                return true
            }
            override fun processingFinished() {}
        })
        dispatcher.run()

        // Stitch the chunks into one contiguous primitive array.
        var totalSamples = 0
        for (chunk in chunks) totalSamples += chunk.size
        val audio = FloatArray(totalSamples)
        var writePos = 0
        for (chunk in chunks) {
            System.arraycopy(chunk, 0, audio, writePos, chunk.size)
            writePos += chunk.size
        }

        val numFrames = (audio.size - n_fft) / hop_length + 1
        if (numFrames <= 0) return FloatArray(0)

        // No window is given to the FFT: the Hann window is applied by hand below,
        // and passing one here as well would window every frame twice.
        val fft = FFT(n_fft)
        val half = n_fft / 2
        val binCount = half + 1
        val powerSpectrum = FloatArray(binCount)
        val melFilterBank = createMelFilterBank(samplingRate, n_fft, num_mels)

        // Precomputed Hann window coefficients.
        val windowCoef = FloatArray(n_fft) { i ->
            (0.5f * (1 - cos(2 * PI * i / (n_fft - 1)))).toFloat()
        }

        val frame = FloatArray(n_fft)
        val result = FloatArray(numFrames * num_mels)
        for (i in 0 until numFrames) {
            System.arraycopy(audio, i * hop_length, frame, 0, n_fft)
            for (j in 0 until n_fft) frame[j] *= windowCoef[j]
            fft.forwardTransform(frame)

            // The in-place real FFT packs its output as
            //   frame[0] = Re[0], frame[1] = Re[n_fft/2],
            //   frame[2k] = Re[k], frame[2k+1] = Im[k] for 0 < k < n_fft/2.
            // Unpack it by hand: asking the library for n_fft/2 + 1 moduli would
            // read frame[n_fft], one pair past the end of the array.
            powerSpectrum[0] = frame[0] * frame[0]
            powerSpectrum[half] = frame[1] * frame[1]
            for (k in 1 until half) {
                val re = frame[2 * k]
                val im = frame[2 * k + 1]
                powerSpectrum[k] = re * re + im * im
            }

            val rowStart = i * num_mels
            for (mel in 0 until num_mels) {
                val filter = melFilterBank[mel]
                var sum = 0.0f
                for (k in 0 until binCount) sum += powerSpectrum[k] * filter[k]
                result[rowStart + mel] = ln(sum + 1e-7f)
            }
        }
        return result
    }

    /**
     * Builds a bank of triangular mel filters.
     *
     * Band edges are n_mels + 2 points spaced evenly on the mel scale
     * (mel = 2595 * log10(1 + hz / 700)) between 0 Hz and sampleRate / 2. Filter m
     * rises linearly from edge m to edge m + 1 and falls linearly to edge m + 2;
     * it is zero everywhere else.
     *
     * @param sampleRate sample rate in Hz.
     * @param n_fft FFT length; the bank covers n_fft / 2 + 1 frequency bins.
     * @param n_mels number of filters.
     * @return n_mels rows of n_fft / 2 + 1 weights each.
     */
    private fun createMelFilterBank(sampleRate: Int, n_fft: Int, n_mels: Int): Array<FloatArray> {
        fun toMel(hz: Float): Float = 2595f * log10(1f + hz / 700f)
        fun toHz(mel: Float): Float = 700f * (10f.pow(mel / 2595f) - 1f)

        // Centre frequency of every FFT bin.
        val bins = n_fft / 2 + 1
        val binHz = FloatArray(bins)
        for (k in 0 until bins) {
            binHz[k] = k.toFloat() * sampleRate / n_fft
        }

        // Band edges: evenly spaced in mel, then converted back to Hz.
        val melLow = toMel(0f)
        val melHigh = toMel(sampleRate / 2f)
        val melEdges = FloatArray(n_mels + 2)
        for (i in melEdges.indices) {
            melEdges[i] = melLow + i.toFloat() / (n_mels + 1) * (melHigh - melLow)
        }
        val edgeHz = FloatArray(melEdges.size)
        for (i in edgeHz.indices) {
            edgeHz[i] = toHz(melEdges[i])
        }

        val bank = Array(n_mels) { FloatArray(bins) }
        for (m in 0 until n_mels) {
            val left = edgeHz[m]
            val centre = edgeHz[m + 1]
            val right = edgeHz[m + 2]
            val row = bank[m]
            for (k in 0 until bins) {
                val f = binHz[k]
                row[k] = if (f >= left && f <= centre) {
                    (f - left) / (centre - left)      // rising slope
                } else if (f >= centre && f <= right) {
                    (right - f) / (right - centre)    // falling slope
                } else {
                    0f
                }
            }
        }
        return bank
    }

    // ---------- Timestep schedule ----------
    /**
     * Returns numStep + 1 timesteps: the evenly spaced points i / numStep warped by
     * t -> t * tShift / (1 + (tShift - 1) * t).
     *
     * With tShift = 1 the schedule is the evenly spaced one unchanged.
     */
    private fun getTimesteps(numStep: Int, tShift: Float): FloatArray {
        val schedule = FloatArray(numStep + 1)
        for (index in schedule.indices) {
            val linear = index.toFloat() / numStep
            schedule[index] = linear * tShift / (1 + (tShift - 1) * linear)
        }
        return schedule
    }

    // ---------- text_encoder inference ----------
    /**
     * Runs the text encoder session and returns its first output as a flat array.
     *
     * Inputs sent to the session: "tokens" and "prompt_tokens" as int64 tensors of
     * shape (1, length), "prompt_features_len" as an int64 tensor of shape (1), and
     * "speed" fixed at 1.0.
     *
     * Every tensor created here, and the session result, is closed before
     * returning, including when the run or the read-back throws.
     *
     * @throws IllegalStateException if the environment or session is not loaded.
     */
    private fun runTextEncoder(tokens: IntArray, promptTokens: IntArray, promptFeaturesLen: Int): FloatArray {
        val env = ortEnv ?: throw IllegalStateException("OrtEnvironment not initialized")
        val session = textEncoderSession ?: throw IllegalStateException("TextEncoder session not initialized")

        val tokenIds = LongArray(tokens.size) { tokens[it].toLong() }
        val promptIds = LongArray(promptTokens.size) { promptTokens[it].toLong() }

        val inputs = mutableMapOf<String, OnnxTensor>()
        try {
            inputs["tokens"] = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenIds), longArrayOf(1, tokenIds.size.toLong()))
            inputs["prompt_tokens"] = OnnxTensor.createTensor(env, LongBuffer.wrap(promptIds), longArrayOf(1, promptIds.size.toLong()))
            inputs["prompt_features_len"] = OnnxTensor.createTensor(env, LongBuffer.wrap(longArrayOf(promptFeaturesLen.toLong())), longArrayOf(1))
            inputs["speed"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(1.0f)), longArrayOf(1))

            val output = session.run(inputs)
            try {
                // Copy the data out before the result (which owns it) is closed.
                return (output[0] as OnnxTensor).floatBuffer.run { FloatArray(remaining()).also { get(it) } }
            } finally {
                output.close()
            }
        } finally {
            inputs.values.forEach { it.close() }
        }
    }

    // ---------- fm_decoder inference ----------
    /**
     * Runs the flow-matching decoder session once and returns its first output as
     * a flat array.
     *
     * [x], [textCondition] and [speechCondition] are sent as float tensors of shape
     * (1, T, featDim) with T = x.size / featDim; [t] and a guidance scale fixed at
     * 1.0 are sent as float tensors of shape (1).
     *
     * Every tensor created here, and the session result, is closed before
     * returning, including when the run or the read-back throws.
     *
     * @throws IllegalStateException if the environment or session is not loaded.
     */
    private fun runFmDecoder(
        t: Float,
        x: FloatArray,
        textCondition: FloatArray,
        speechCondition: FloatArray
    ): FloatArray {
        val env = ortEnv ?: throw IllegalStateException("OrtEnvironment not initialized")
        val session = fmDecoderSession ?: throw IllegalStateException("FmDecoder session not initialized")
        val T = x.size / featDim
        val sequenceShape = longArrayOf(1, T.toLong(), featDim.toLong())

        val inputs = mutableMapOf<String, OnnxTensor>()
        try {
            inputs["t"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(t)), longArrayOf(1))
            inputs["x"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(x), sequenceShape)
            inputs["text_condition"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(textCondition), sequenceShape)
            inputs["speech_condition"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(speechCondition), sequenceShape)
            inputs["guidance_scale"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(1.0f)), longArrayOf(1))

            val output = session.run(inputs)
            try {
                // Copy the data out before the result (which owns it) is closed.
                return (output[0] as OnnxTensor).floatBuffer.run { FloatArray(remaining()).also { get(it) } }
            } finally {
                output.close()
            }
        } finally {
            inputs.values.forEach { it.close() }
        }
    }

    // ---------- Vocoder ----------
    /**
     * Converts mel features to a waveform with the PyTorch vocoder.
     *
     * [melFeatures] arrives frame after frame, i.e. laid out as (T, featDim) — the
     * same layout the decoder tensors use. The vocoder is fed a (1, featDim, T)
     * tensor, so the buffer is transposed first; relabelling the shape without
     * moving the data would interleave frames and channels.
     *
     * NOTE(review): no feature rescaling is applied before decoding — confirm the
     * vocoder expects the decoder's output scale.
     *
     * @param melFeatures flat (T, featDim) features; the size must be a multiple of featDim.
     * @return the vocoder's output samples.
     * @throws IllegalStateException if the vocoder is not loaded or returns no data.
     * @throws IllegalArgumentException if the size is not a multiple of featDim.
     */
    private fun runVocoder(melFeatures: FloatArray): FloatArray {
        val module = vocoderModule ?: throw IllegalStateException("Vocoder module not initialized")
        require(melFeatures.size % featDim == 0) {
            "melFeatures size ${melFeatures.size} is not a multiple of featDim $featDim"
        }
        val T = melFeatures.size / featDim

        // (T, featDim) -> (featDim, T)
        val channelsFirst = FloatArray(featDim * T)
        for (frame in 0 until T) {
            val frameStart = frame * featDim
            for (channel in 0 until featDim) {
                channelsFirst[channel * T + frame] = melFeatures[frameStart + channel]
            }
        }

        val inputTensor = Tensor.fromBlob(channelsFirst, longArrayOf(1, featDim.toLong(), T.toLong()))
        val outputIValue = module.forward(IValue.from(inputTensor))
        val outputTensor = outputIValue.toTensor()
        return outputTensor.dataAsFloatArray ?: throw IllegalStateException("Vocoder output is null")
    }

    // ---------- Main synthesis entry point ----------
    /**
     * Synthesises speech for [text].
     *
     * Steps: tokenise the text (and the prompt text when a prompt recording is
     * given), extract the prompt's mel features, run the text encoder, start from
     * Gaussian noise, integrate the decoder's output over 16 Euler steps, drop the
     * prompt frames and run the vocoder on what remains.
     *
     * @param text text to speak, as whitespace-separated tokens.
     * @param promptWavPath optional recording used as the voice prompt.
     * @param promptText transcript of the prompt; ignored when no prompt is given.
     * @return the vocoder's output samples, or null when the manager is not
     *         initialised, the text is blank, or any step fails (the failure is logged).
     */
    fun synthesize(text: String, promptWavPath: String? = null, promptText: String = ""): FloatArray? {
        if (!isInitialized) { Log.e(tag, "TTS 未初始化"); return null }
        if (text.isBlank()) {
            // Nothing to speak: avoid sending an empty token tensor to the encoder.
            Log.w(tag, "synthesize called with blank text")
            return null
        }
        return try {
            val tokens = textToTokenIds(text)
            val promptTokens = if (promptWavPath != null) textToTokenIds(promptText) else intArrayOf()
            val promptFeatures = promptWavPath?.let { extractMelFromWav(it) }
            val promptLen = (promptFeatures?.size ?: 0) / featDim

            val textCondition = runTextEncoder(tokens, promptTokens, promptLen)
            val T = textCondition.size / featDim

            // Initial state: Gaussian noise from a single generator, rather than
            // allocating and seeding a new generator for every sample.
            val noise = Random()
            val currentX = FloatArray(T * featDim) { noise.nextGaussian().toFloat() }

            // Prompt features fill the leading frames; the rest stays zero.
            val speechCondition = FloatArray(T * featDim).apply {
                promptFeatures?.let { feat ->
                    val copyLen = min(feat.size, size)
                    System.arraycopy(feat, 0, this, 0, copyLen)
                }
            }

            val numStep = 16
            val tShift = 0.5f
            val timesteps = getTimesteps(numStep, tShift)

            // Euler integration: x <- x + v(t, x) * dt, updated in place.
            for (step in 0 until numStep) {
                val t = timesteps[step]
                val nextT = timesteps[step + 1]
                val v = runFmDecoder(t, currentX, textCondition, speechCondition)
                check(v.size == currentX.size) {
                    "Decoder returned ${v.size} values, expected ${currentX.size}"
                }
                val dt = nextT - t
                for (i in currentX.indices) currentX[i] += v[i] * dt
            }

            // Keep only the frames generated after the prompt.
            val predFeatures = if (promptLen > 0) currentX.sliceArray(promptLen * featDim until currentX.size) else currentX
            runVocoder(predFeatures)
        } catch (e: Exception) {
            Log.e(tag, "合成失败", e)
            null
        }
    }

    /**
     * Frees the sessions, environment and vocoder and clears the token table.
     *
     * The ready flag is cleared first so new [synthesize] calls are rejected while
     * resources are being torn down. Each reference is dropped after it is closed,
     * which makes a repeated call harmless instead of closing the same session twice.
     */
    fun release() {
        isInitialized = false

        textEncoderSession?.close()
        textEncoderSession = null
        fmDecoderSession?.close()
        fmDecoderSession = null
        ortEnv?.close()
        ortEnv = null

        // Free the vocoder's native module now rather than whenever it is garbage-collected.
        vocoderModule?.destroy()
        vocoderModule = null

        tokenToId.clear()
        Log.i(tag, "✅ ZipVoice TTS 资源已释放")
    }
}
// (Paste-site footer, not part of the source: "Editor is loading...", "Leave a Comment".)