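# Paste-site header residue (kept as a comment so the file parses as Python):
#   title: "new py" | author: lalisadnn (avatar) | format: plain_text
#   posted: 10 months ago | size: 2.1 kB | "3" | "Indexable"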
import os
import tempfile
import threading

import librosa
import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Path to the TFLite ASR model. Set the ASR_MODEL_PATH environment variable to
# point at a real model; the literal default is the original placeholder.
model_path = os.environ.get("ASR_MODEL_PATH", "path_to_your_model.tflite")

# Load the TFLite model once, at import time. The interpreter (and the tensor
# metadata below) is module-level state shared by every request `app` serves.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Lists of per-tensor metadata dicts (keys such as 'index' and 'shape') for the
# model's inputs and outputs.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def preprocess_audio(file_path, target_sr):
    """Load an audio file at ``target_sr`` and peak-normalize it.

    Args:
        file_path: Path of the audio file to decode with ``librosa.load``.
        target_sr: Desired sampling rate in Hz. The signal is resampled with
            ``librosa.resample`` only when the file's native rate differs.

    Returns:
        The waveform (mono, per ``librosa.load``'s default) scaled so its
        largest absolute sample is 1.0. Silent (all-zero) or empty audio is
        returned unscaled rather than divided by a zero peak.
    """
    # sr=None keeps the native rate, so resampling happens only when needed.
    audio, sr = librosa.load(file_path, sr=None)
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

    # Guard the normalization: dividing by a zero peak fills the signal with
    # NaNs, and np.max() on an empty array raises ValueError.
    peak = np.max(np.abs(audio)) if audio.size else 0.0
    if peak > 0:
        audio = audio / peak
    return audio

# Serializes use of the shared module-level `interpreter`. Flask's development
# server handles requests on multiple threads, and one request's
# set_tensor -> invoke -> get_tensor sequence must not interleave with another's.
_interpreter_lock = threading.Lock()


def run_asr_model(audio_path):
    """Transcribe one audio file with the module-level TFLite interpreter.

    The waveform from ``preprocess_audio`` (at 16 kHz) is truncated or
    zero-padded so it exactly fills the model's first input tensor. The model
    is then run, and row 0 of the first output tensor is decoded as Unicode
    code points, skipping zero entries.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The recognized text as a ``str``.
    """
    input_detail = input_details[0]
    input_shape = tuple(int(dim) for dim in input_detail['shape'])
    desired_sr = 16000  # Example target sampling rate
    audio = preprocess_audio(audio_path, desired_sr)

    # Fit the waveform to the fixed input size. np.resize would *repeat* the
    # clip to fill a larger tensor (the model would hear the audio again), so
    # truncate the tail or zero-pad instead. The result already has exactly
    # `input_shape`; adding another axis afterwards (as np.expand_dims did)
    # makes set_tensor reject the input with a dimension mismatch.
    n_values = int(np.prod(input_shape))
    samples = np.ravel(audio)[:n_values]
    model_input = np.zeros(n_values, dtype=audio.dtype)
    model_input[:samples.size] = samples
    model_input = model_input.reshape(input_shape)

    # Hold the lock across the whole sequence so a concurrent request cannot
    # overwrite this request's input before invoke() or its output before
    # it is read.
    with _interpreter_lock:
        interpreter.set_tensor(input_detail['index'], model_input)
        interpreter.invoke()
        # get_tensor returns a copy, so it stays valid after the lock is released.
        output_data = interpreter.get_tensor(output_details[0]['index'])

    # NOTE(review): assumes the model emits integer Unicode code points for
    # batch row 0, padded with zeros. Confirm against the model's output spec.
    unicode_points = output_data[0]
    return ''.join(chr(code_point) for code_point in unicode_points if code_point != 0)

@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe an uploaded audio file.

    Expects a multipart/form-data POST carrying the audio in the ``audio``
    field.

    Returns:
        JSON ``{"transcript": ...}`` on success, or ``{"error": ...}`` with
        HTTP 400 when no audio file was uploaded.
    """
    audio_file = request.files.get('audio')
    # A form submitted without choosing a file sends a part with an empty
    # filename, so treat that the same as a missing field.
    if audio_file is None or not audio_file.filename:
        return jsonify({'error': 'No audio file provided'}), 400

    # Save to a unique temporary file rather than /tmp/<client filename>.
    # Concurrent uploads with the same name would otherwise overwrite each
    # other, and secure_filename() may return '' (the old path then pointed
    # at /tmp itself). Only the sanitized extension is kept, in case the
    # decoder relies on it.
    suffix = os.path.splitext(secure_filename(audio_file.filename))[1]
    fd, file_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, 'wb') as tmp:
            audio_file.save(tmp)
        transcript = run_asr_model(file_path)
    finally:
        # Delete the upload even when decoding or inference raises.
        os.remove(file_path)

    return jsonify({'transcript': transcript})

if __name__ == '__main__':
    # Listen on all network interfaces. The Werkzeug debugger is not forced on
    # because it can execute arbitrary Python code from the browser (it is
    # PIN-protected, but still a major risk on a 0.0.0.0 bind). Set the
    # environment variable FLASK_DEBUG=1 to opt in during local development.
    app.run(host='0.0.0.0', port=5000)
# Paste-site footer residue: "Editor is loading..." / "Leave a Comment"