// ZipVoiceTTSManager.kt (paste-page metadata: unknown · plain_text · 3 months ago · 15 kB · 8 · Indexable)
package com.angel.zipvoice
import android.content.Context
import android.util.Log
import ai.onnxruntime.*
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.IValue
import java.io.*
import java.nio.*
import java.util.*
import kotlin.math.*
import be.tarsos.dsp.AudioDispatcher
import be.tarsos.dsp.AudioEvent
import be.tarsos.dsp.AudioProcessor
import be.tarsos.dsp.io.jvm.AudioDispatcherFactory
import be.tarsos.dsp.util.fft.FFT
import be.tarsos.dsp.util.fft.HannWindow
class ZipVoiceTTSManager private constructor(private val context: Context) {
private val tag = "ZipVoiceTTS"

/**
 * True once [initModel] has loaded every model; reset to false by [release] or a failed init.
 * Written on the background init thread and read by callers of [synthesize], hence volatile.
 */
@Volatile
var isInitialized = false
    private set

// ONNX Runtime environment and the two inference sessions.
private var ortEnv: OrtEnvironment? = null
private var textEncoderSession: OrtSession? = null
private var fmDecoderSession: OrtSession? = null

// PyTorch vocoder.
private var vocoderModule: Module? = null

// Model parameters.
// NOTE(review): the snake_case names are kept because other members reference them as-is.
private val featDim = 100
private val samplingRate = 24000
private val n_fft = 1024
private val hop_length = 256
private val num_mels = 100

// Vocabulary: token string -> integer id, filled by loadTokens(). Never reassigned, only mutated.
private val tokenToId: MutableMap<String, Int> = HashMap()
/** Callback for the outcome of the model initialisation started by this manager. */
interface OnInitListener {
/** Invoked after all models have been loaded and the manager is marked initialised. */
fun onSuccess()
/** Invoked when initialisation fails; [msg] is the exception message or a fallback text. */
fun onError(msg: String)
}
// Written by whichever thread registers the listener and read on the background init thread,
// so it is volatile to make the registration visible to that thread.
@Volatile
private var initListener: OnInitListener? = null

/**
 * Registers the [listener] that is told how initialisation ended.
 *
 * NOTE(review): initialisation is started from the constructor, so a listener registered after
 * it has already finished is never called; callers should also check [isInitialized].
 */
fun setOnInitListener(listener: OnInitListener) {
    initListener = listener
}
companion object {
    @Volatile
    private var instance: ZipVoiceTTSManager? = null

    /**
     * Returns the shared manager, creating it on first use. Creating it also starts the
     * background model initialisation. Only the application context of [context] is retained.
     */
    fun getInstance(context: Context): ZipVoiceTTSManager =
        instance ?: synchronized(this) {
            instance ?: ZipVoiceTTSManager(context.applicationContext).also { instance = it }
        }

    /**
     * Releases the shared manager, if any, and forgets it so that the next [getInstance]
     * builds a fresh one. Takes the same lock as [getInstance] so a concurrent caller cannot
     * be handed the instance that is being torn down.
     */
    fun releaseInstance() {
        synchronized(this) {
            instance?.release()
            instance = null
        }
    }
}
init {
    // Model loading copies assets and loads native libraries, so it runs on its own thread.
    val loader = Thread(Runnable { initModel() })
    loader.start()
}
/**
 * Copies the asset at [assetPath] into the app's files directory (once) and returns the target file.
 *
 * The target name is "zipvoice_" plus the asset path with '/' replaced by '_'. The data is first
 * written to a temporary sibling file and then renamed, so an interrupted or failed copy can never
 * leave a truncated file that a later run would mistake for a complete model.
 *
 * Copy failures are logged, not thrown; in that case the returned file does not exist.
 */
private fun copyAssetToFile(assetPath: String): File {
    val destFile = File(context.filesDir, "zipvoice_" + assetPath.replace('/', '_'))
    if (destFile.exists()) return destFile

    destFile.parentFile?.mkdirs()
    val tempFile = File(destFile.parentFile, destFile.name + ".tmp")
    try {
        context.assets.open(assetPath).use { input ->
            FileOutputStream(tempFile).use { output -> input.copyTo(output) }
        }
        if (!tempFile.renameTo(destFile)) {
            throw IOException("Cannot rename ${tempFile.name} to ${destFile.name}")
        }
        Log.d(tag, "✅ 模型文件已复制: $assetPath")
    } catch (e: IOException) {
        Log.e(tag, "复制文件失败: $assetPath", e)
        // Drop the partial copy so the next attempt starts from scratch.
        tempFile.delete()
    }
    return destFile
}
/**
 * Rebuilds [tokenToId] from [tokensFile], which holds one "token<TAB>id" pair per line.
 *
 * Lines that do not split into exactly two tab-separated fields are ignored, as before. A line
 * whose id is not an integer is now skipped with a warning instead of aborting the whole load
 * and leaving a partially filled vocabulary. I/O failures are logged and leave the map with
 * whatever was read so far.
 *
 * NOTE(review): each line is trimmed before splitting, so a token consisting only of whitespace
 * cannot be represented — confirm the vocabulary has no such token.
 */
private fun loadTokens(tokensFile: File) {
    tokenToId.clear()
    try {
        tokensFile.forEachLine(Charsets.UTF_8) { rawLine ->
            val fields = rawLine.trim().split("\t")
            if (fields.size == 2) {
                val id = fields[1].toIntOrNull()
                if (id != null) {
                    tokenToId[fields[0]] = id
                } else {
                    Log.w(tag, "跳过无法解析的 token 行: $rawLine")
                }
            }
        }
        Log.i(tag, "✅ 加载了 ${tokenToId.size} 个 token")
    } catch (e: Exception) {
        Log.e(tag, "加载 tokens.txt 失败", e)
    }
}
/**
 * Loads everything the synthesiser needs: model assets, token table, the ONNX Runtime native
 * libraries, the two ONNX sessions and the PyTorch vocoder.
 *
 * Called from the thread started in the constructor, so the [OnInitListener] callbacks are
 * delivered on that background thread. On success [isInitialized] becomes true; on any failure
 * it is set to false and the listener receives the error message.
 */
private fun initModel() {
    try {
        Log.i(tag, "========== ZipVoice TTS 初始化开始 ==========")
        // 1. Copy the ONNX models, the token table and the vocoder out of the assets.
        val textEncoderFile = copyAssetToFile("models/tts/zipvoice/text_encoder.onnx")
        val fmDecoderFile = copyAssetToFile("models/tts/zipvoice/fm_decoder.onnx")
        val tokensFile = copyAssetToFile("models/tts/zipvoice/tokens.txt")
        val vocoderFile = copyAssetToFile("models/tts/zipvoice/vocoder.pt")
        loadTokens(tokensFile)

        // 2. Copy and load the ONNX Runtime core library and its JNI bridge.
        val coreFile = File(context.filesDir, "libonnxruntime.so")
        val jniFile = File(context.filesDir, "libonnxruntime4j_jni.so")
        if (copyAssetIfAbsent("libs/arm64-v8a/libonnxruntime.so", coreFile)) {
            Log.i(tag, "✅ 标准版核心库已复制到: ${coreFile.absolutePath}")
        }
        if (copyAssetIfAbsent("libs/arm64-v8a/libonnxruntime4j_jni.so", jniFile)) {
            Log.i(tag, "✅ JNI 桥接库已复制到: ${jniFile.absolutePath}")
        }
        // The core library is loaded before the JNI bridge, as in the original ordering.
        System.load(coreFile.absolutePath)
        Log.i(tag, "✅ 标准版 ONNX Runtime 核心库加载成功")
        System.load(jniFile.absolutePath)
        Log.i(tag, "✅ ONNX Runtime JNI 桥接库加载成功")

        // 3. Create the ONNX Runtime environment and both sessions.
        val env = OrtEnvironment.getEnvironment()
        ortEnv = env
        val sessionOptions = OrtSession.SessionOptions()
        try {
            sessionOptions.setIntraOpNumThreads(2)
            textEncoderSession = env.createSession(textEncoderFile.absolutePath, sessionOptions)
            fmDecoderSession = env.createSession(fmDecoderFile.absolutePath, sessionOptions)
        } finally {
            // The options object owns native memory and is not needed once the sessions exist.
            sessionOptions.close()
        }

        // 4. Load the PyTorch vocoder.
        vocoderModule = Module.load(vocoderFile.absolutePath)

        isInitialized = true
        Log.i(tag, "🎉 ZipVoice TTS 初始化成功")
        initListener?.onSuccess()
    } catch (e: Exception) {
        reportInitFailure(e)
    } catch (e: LinkageError) {
        // System.load signals a missing or incompatible native library with UnsatisfiedLinkError,
        // which is an Error and would otherwise escape and kill the init thread unreported.
        reportInitFailure(e)
    }
}

/** Logs an initialisation failure, marks the manager as not initialised and notifies the listener. */
private fun reportInitFailure(cause: Throwable) {
    Log.e(tag, "❌ ZipVoice TTS 初始化失败", cause)
    isInitialized = false
    initListener?.onError(cause.message ?: "未知错误")
}

/**
 * Copies the asset [assetPath] to [destFile] unless that file already exists.
 * Writes to a temporary sibling first and renames it, so a failed copy never leaves a truncated
 * library behind. Returns true when a copy was made; throws [IOException] when it fails.
 */
private fun copyAssetIfAbsent(assetPath: String, destFile: File): Boolean {
    if (destFile.exists()) return false
    val tempFile = File(destFile.parentFile, destFile.name + ".tmp")
    try {
        context.assets.open(assetPath).use { input ->
            FileOutputStream(tempFile).use { output -> input.copyTo(output) }
        }
        if (!tempFile.renameTo(destFile)) {
            throw IOException("Cannot rename ${tempFile.name} to ${destFile.name}")
        }
    } catch (e: IOException) {
        tempFile.delete()
        throw e
    }
    return true
}
// ---------- Tokenizer ----------
/**
 * Maps whitespace-separated words of [text] to their ids in [tokenToId].
 * Unknown words fall back to the id of "<unk>", or 0 when that entry is missing too.
 * Blank input yields an empty array.
 */
private fun textToTokenIds(text: String): IntArray {
    if (text.isBlank()) return intArrayOf()
    // Simple tokenisation on whitespace; in real use this has to match the training-time tokeniser.
    val words = text.trim().split(Regex("\\s+"))
    val ids = IntArray(words.size)
    for ((position, word) in words.withIndex()) {
        ids[position] = tokenToId[word] ?: tokenToId["<unk>"] ?: 0
    }
    return ids
}
// ---------- 梅尔谱提取 (手动应用汉宁窗) ----------
/**
 * Computes a log-mel spectrogram of the audio file at [filePath].
 *
 * Frames are [n_fft] samples long and advance by [hop_length]; each frame is multiplied by a
 * Hann window, transformed, turned into a power spectrum of n_fft/2 + 1 bins, projected onto
 * [num_mels] triangular filters and stored as ln(energy + 1e-7).
 *
 * No resampling is done: the filter bank is built for [samplingRate], so the file is assumed
 * to already use that rate.
 *
 * NOTE(review): the `io.jvm` dispatcher factory presumably relies on javax.sound, which Android
 * does not provide — confirm this decoding path works on device.
 *
 * @return frame-major flattened features of size numFrames * num_mels; empty when the audio is
 *         shorter than one analysis window.
 */
fun extractMelFromWav(filePath: String): FloatArray {
    val audio = readAllSamples(File(filePath))
    if (audio.size < n_fft) return FloatArray(0)

    val numFrames = (audio.size - n_fft) / hop_length + 1
    val halfSize = n_fft / 2
    val binCount = halfSize + 1
    // The window is applied by hand below, so the FFT itself is created without one.
    val fft = FFT(n_fft)
    val melFilterBank = createMelFilterBank(samplingRate, n_fft, num_mels)
    // Pre-computed Hann window coefficients.
    val windowCoef = FloatArray(n_fft) { i ->
        (0.5f * (1 - cos(2 * PI * i / (n_fft - 1)))).toFloat()
    }

    val frame = FloatArray(n_fft)
    val powerSpectrum = FloatArray(binCount)
    val result = FloatArray(numFrames * num_mels)
    for (i in 0 until numFrames) {
        val start = i * hop_length
        for (j in 0 until n_fft) frame[j] = audio[start + j] * windowCoef[j]
        fft.forwardTransform(frame)

        // The real FFT output is packed into n_fft floats: frame[0] = Re[0], frame[1] = Re[n/2],
        // and (frame[2k], frame[2k+1]) = (Re[k], Im[k]) for 0 < k < n/2. Reading n/2 + 1 complex
        // pairs from it (as FFT.modulus would with a binCount-sized output) runs past the array.
        powerSpectrum[0] = frame[0] * frame[0]
        powerSpectrum[halfSize] = frame[1] * frame[1]
        for (k in 1 until halfSize) {
            val re = frame[2 * k]
            val im = frame[2 * k + 1]
            powerSpectrum[k] = re * re + im * im
        }

        val rowOffset = i * num_mels
        for (mel in 0 until num_mels) {
            val weights = melFilterBank[mel]
            var sum = 0.0f
            for (k in 0 until binCount) sum += powerSpectrum[k] * weights[k]
            result[rowOffset + mel] = ln(sum + 1e-7f)
        }
    }
    return result
}

/**
 * Decodes [file] into one contiguous sample array.
 * Buffers are requested with zero overlap: every buffer is appended in full, so any overlap
 * between consecutive buffers would duplicate samples in the result.
 */
private fun readAllSamples(file: File): FloatArray {
    val chunks = ArrayList<FloatArray>()
    var totalSamples = 0
    val dispatcher = AudioDispatcherFactory.fromFile(file, n_fft, 0)
    dispatcher.addAudioProcessor(object : AudioProcessor {
        override fun process(audioEvent: AudioEvent): Boolean {
            // Copy, because the buffer content is only read here and stored for later use.
            val chunk = audioEvent.floatBuffer.copyOf(audioEvent.bufferSize)
            chunks.add(chunk)
            totalSamples += chunk.size
            return true
        }

        override fun processingFinished() {}
    })
    dispatcher.run()

    val audio = FloatArray(totalSamples)
    var offset = 0
    for (chunk in chunks) {
        System.arraycopy(chunk, 0, audio, offset, chunk.size)
        offset += chunk.size
    }
    return audio
}
/**
 * Builds [n_mels] triangular filters over the n_fft/2 + 1 FFT bin frequencies.
 *
 * Filter edges are spaced evenly on the mel scale (2595 * log10(1 + hz / 700)) between 0 Hz and
 * sampleRate / 2. Row m rises linearly from edge m to edge m + 1 and falls back to zero at
 * edge m + 2; every other bin gets weight 0.
 *
 * @return an n_mels x (n_fft / 2 + 1) weight matrix.
 */
private fun createMelFilterBank(sampleRate: Int, n_fft: Int, n_mels: Int): Array<FloatArray> {
    val bins = n_fft / 2 + 1
    val toMel = { hz: Float -> 2595f * log10(1f + hz / 700f) }
    val toHz = { mel: Float -> 700f * (10f.pow(mel / 2595f) - 1f) }

    val lowestMel = toMel(0f)
    val highestMel = toMel(sampleRate / 2f)

    // n_mels + 2 edge frequencies: each filter uses three consecutive edges.
    val edgesHz = FloatArray(n_mels + 2)
    for (p in edgesHz.indices) {
        val melPoint = lowestMel + p.toFloat() / (n_mels + 1) * (highestMel - lowestMel)
        edgesHz[p] = toHz(melPoint)
    }

    val binFreqs = FloatArray(bins)
    for (k in 0 until bins) binFreqs[k] = k.toFloat() * sampleRate / n_fft

    val bank = Array(n_mels) { FloatArray(bins) }
    for (m in 0 until n_mels) {
        val lower = edgesHz[m]
        val peak = edgesHz[m + 1]
        val upper = edgesHz[m + 2]
        val row = bank[m]
        for (k in 0 until bins) {
            val f = binFreqs[k]
            row[k] = if (f >= lower && f <= peak) {
                (f - lower) / (peak - lower)
            } else if (f >= peak && f <= upper) {
                (upper - f) / (upper - peak)
            } else {
                0f
            }
        }
    }
    return bank
}
// ---------- 时间步生成 ----------
/**
 * Produces the numStep + 1 time points used by the sampling loop.
 * Each uniform point t = i / numStep is warped to t * tShift / (1 + (tShift - 1) * t).
 */
private fun getTimesteps(numStep: Int, tShift: Float): FloatArray {
    val schedule = FloatArray(numStep + 1)
    for (idx in schedule.indices) {
        val uniform = idx.toFloat() / numStep
        schedule[idx] = uniform * tShift / (1 + (tShift - 1) * uniform)
    }
    return schedule
}
// ---------- text_encoder 推理 ----------
/**
 * Runs the text encoder session and returns its first output as a flat float array.
 *
 * Feeds the inputs "tokens" and "prompt_tokens" (int64, shape [1, n]), "prompt_features_len"
 * (int64, shape [1]) and "speed" (float, fixed at 1.0). Input tensors and the session result
 * are closed on every path, including when reading the output fails.
 *
 * @throws IllegalStateException when the environment or session has not been initialised.
 */
private fun runTextEncoder(tokens: IntArray, promptTokens: IntArray, promptFeaturesLen: Int): FloatArray {
    val env = ortEnv ?: throw IllegalStateException("OrtEnvironment not initialized")
    val session = textEncoderSession ?: throw IllegalStateException("TextEncoder session not initialized")
    val inputs = mutableMapOf<String, OnnxTensor>()
    try {
        val tokenIds = LongArray(tokens.size) { tokens[it].toLong() }
        val promptIds = LongArray(promptTokens.size) { promptTokens[it].toLong() }
        inputs["tokens"] = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenIds), longArrayOf(1, tokens.size.toLong()))
        inputs["prompt_tokens"] = OnnxTensor.createTensor(env, LongBuffer.wrap(promptIds), longArrayOf(1, promptTokens.size.toLong()))
        inputs["prompt_features_len"] = OnnxTensor.createTensor(env, LongBuffer.wrap(longArrayOf(promptFeaturesLen.toLong())), longArrayOf(1))
        inputs["speed"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(1.0f)), longArrayOf(1))

        val output = session.run(inputs)
        try {
            val buffer = (output[0] as OnnxTensor).floatBuffer
            return FloatArray(buffer.remaining()).also { buffer.get(it) }
        } finally {
            // Closed here rather than after the read so a failed cast/read cannot leak the result.
            output.close()
        }
    } finally {
        inputs.values.forEach { it.close() }
    }
}
// ---------- fm_decoder 推理 ----------
/**
 * Runs one evaluation of the fm_decoder session and returns its first output as a flat array.
 *
 * [x], [textCondition] and [speechCondition] are each passed with shape [1, T, featDim], where
 * T = x.size / featDim. [t] is the current time value; "guidance_scale" is fixed at 1.0.
 * Input tensors and the session result are closed on every path.
 *
 * @throws IllegalStateException when the environment or session has not been initialised.
 */
private fun runFmDecoder(
    t: Float,
    x: FloatArray,
    textCondition: FloatArray,
    speechCondition: FloatArray
): FloatArray {
    val env = ortEnv ?: throw IllegalStateException("OrtEnvironment not initialized")
    val session = fmDecoderSession ?: throw IllegalStateException("FmDecoder session not initialized")
    val frameCount = x.size / featDim
    val featureShape = longArrayOf(1, frameCount.toLong(), featDim.toLong())
    val inputs = mutableMapOf<String, OnnxTensor>()
    try {
        inputs["t"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(t)), longArrayOf(1))
        inputs["x"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(x), featureShape)
        inputs["text_condition"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(textCondition), featureShape)
        inputs["speech_condition"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(speechCondition), featureShape)
        inputs["guidance_scale"] = OnnxTensor.createTensor(env, FloatBuffer.wrap(floatArrayOf(1.0f)), longArrayOf(1))

        val output = session.run(inputs)
        try {
            val buffer = (output[0] as OnnxTensor).floatBuffer
            return FloatArray(buffer.remaining()).also { buffer.get(it) }
        } finally {
            // Closed here rather than after the read so a failed cast/read cannot leak the result.
            output.close()
        }
    } finally {
        inputs.values.forEach { it.close() }
    }
}
// ---------- 声码器 ----------
/**
 * Converts mel features into a waveform with the PyTorch vocoder.
 *
 * [melFeatures] arrives frame-major (T frames of [featDim] values each): that is the layout the
 * decoder tensors use ([1, T, featDim]) and the layout the caller slices by. The vocoder input is
 * declared as [1, featDim, T], so the data is transposed first; merely relabelling the shape
 * would scramble frames and channels.
 *
 * NOTE(review): no rescaling of the features is applied before the vocoder — confirm the model
 * expects them in exactly the range the decoder produces.
 *
 * @throws IllegalStateException when the vocoder has not been loaded.
 * @throws IllegalArgumentException when the input length is not a multiple of [featDim].
 */
private fun runVocoder(melFeatures: FloatArray): FloatArray {
    val module = vocoderModule ?: throw IllegalStateException("Vocoder module not initialized")
    require(melFeatures.size % featDim == 0) {
        "Mel feature length ${melFeatures.size} is not a multiple of $featDim"
    }
    val frameCount = melFeatures.size / featDim

    // (T, featDim) -> (featDim, T)
    val channelMajor = FloatArray(melFeatures.size)
    for (frame in 0 until frameCount) {
        val base = frame * featDim
        for (channel in 0 until featDim) {
            channelMajor[channel * frameCount + frame] = melFeatures[base + channel]
        }
    }

    val inputTensor = Tensor.fromBlob(channelMajor, longArrayOf(1, featDim.toLong(), frameCount.toLong()))
    val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
    return outputTensor.dataAsFloatArray ?: throw IllegalStateException("Vocoder output is null")
}
// ---------- 主合成函数 ----------
/**
 * Synthesises speech for [text] and returns the vocoder output samples, or null on failure.
 *
 * Steps: tokenise the text (and [promptText] when a prompt file is given), extract prompt mel
 * features from [promptWavPath], run the text encoder, integrate the decoder output over 16
 * Euler steps starting from Gaussian noise, drop the prompt frames and run the vocoder.
 *
 * Returns null (after logging) when the manager is not initialised, when [text] is blank, when
 * the encoder yields no frames, or when any step throws.
 */
fun synthesize(text: String, promptWavPath: String? = null, promptText: String = ""): FloatArray? {
    if (!isInitialized) {
        Log.e(tag, "TTS 未初始化")
        return null
    }
    if (text.isBlank()) {
        // Nothing to say: avoid feeding zero-length tensors to the models.
        Log.e(tag, "合成文本为空")
        return null
    }
    return try {
        val tokens = textToTokenIds(text)
        val promptTokens = if (promptWavPath != null) textToTokenIds(promptText) else intArrayOf()
        val promptFeatures = promptWavPath?.let { extractMelFromWav(it) }
        val promptLen = (promptFeatures?.size ?: 0) / featDim

        val textCondition = runTextEncoder(tokens, promptTokens, promptLen)
        val frameCount = textCondition.size / featDim
        if (frameCount == 0) {
            Log.e(tag, "文本编码器未输出任何帧")
            return null
        }

        // One generator for the whole noise tensor instead of a new Random per sample.
        val rng = Random()
        val currentX = FloatArray(frameCount * featDim) { rng.nextGaussian().toFloat() }

        // Prompt features occupy the leading part of the speech condition; the rest stays zero.
        val speechCondition = FloatArray(frameCount * featDim)
        if (promptFeatures != null) {
            System.arraycopy(promptFeatures, 0, speechCondition, 0, min(promptFeatures.size, speechCondition.size))
        }

        val numStep = 16
        val tShift = 0.5f
        val timesteps = getTimesteps(numStep, tShift)
        for (step in 0 until numStep) {
            val t = timesteps[step]
            val dt = timesteps[step + 1] - t
            val v = runFmDecoder(t, currentX, textCondition, speechCondition)
            // Euler update, in place.
            for (i in currentX.indices) currentX[i] += v[i] * dt
        }

        // Drop the frames that correspond to the prompt so only the new speech is vocoded.
        val predFeatures =
            if (promptLen > 0) currentX.sliceArray(promptLen * featDim until currentX.size) else currentX
        runVocoder(predFeatures)
    } catch (e: Exception) {
        Log.e(tag, "合成失败", e)
        null
    }
}
/**
 * Frees the native resources held by this manager and clears the vocabulary.
 *
 * Safe to call more than once: every reference is dropped after it is closed, so a repeated
 * call finds nothing left to close. A failure while closing one resource is logged and does not
 * stop the remaining ones from being released.
 *
 * NOTE(review): nothing here waits for the background initialisation thread; releasing while
 * it is still running may leave resources it creates afterwards — confirm callers avoid that.
 */
fun release() {
    // Flip the flag first so synthesize() stops accepting work.
    isInitialized = false

    closeQuietly("text encoder", textEncoderSession)
    textEncoderSession = null
    closeQuietly("fm decoder", fmDecoderSession)
    fmDecoderSession = null
    closeQuietly("OrtEnvironment", ortEnv)
    ortEnv = null

    // Release the vocoder's native module explicitly instead of waiting for garbage collection.
    try {
        vocoderModule?.destroy()
    } catch (e: Exception) {
        Log.w(tag, "关闭 vocoder 失败", e)
    }
    vocoderModule = null

    tokenToId.clear()
    Log.i(tag, "✅ ZipVoice TTS 资源已释放")
}

/** Closes [resource] if present, logging instead of propagating any failure. */
private fun closeQuietly(name: String, resource: AutoCloseable?) {
    try {
        resource?.close()
    } catch (e: Exception) {
        Log.w(tag, "关闭 $name 失败", e)
    }
}
}