Untitled

 avatar
unknown
plain_text
5 months ago
26 kB
6
Indexable
#onnxleduymanh

//build.gradle
// Gradle plugins applied to this app module.
plugins {
    // Android application plugin, resolved through the version-catalog alias.
    alias(libs.plugins.android.application)
    // Chaquopy plugin; the app code uses its com.chaquo.python classes to run Python.
    id("com.chaquo.python")
}

// Android build configuration for the app module: SDK levels, Python sources and pip
// requirements, native ABIs, product flavors and build types.
android {
    namespace 'com.sec.android.app.genremusiconnx'
    compileSdk 34

    defaultConfig {
        applicationId "com.sec.android.app.genremusiconnx"
        minSdk 32
        targetSdk 33
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"

        // Location of the Python sources bundled with the app.
        // NOTE(review): sourceSets is usually declared directly under android {} rather than
        // inside defaultConfig - confirm this placement is intentional.
        sourceSets {
            main{
                python{
                    srcDirs = ['src/main/python']
                }
            }
        }

        // Native ABIs that are packaged into the APK.
        ndk {
            // On Apple silicon, you can omit x86_64.
            abiFilters "arm64-v8a", "x86_64"
        }

        // Python version and pip requirements.
        // NOTE(review): the Python version is also set in the top-level chaquopy {} block of
        // this file - confirm which of the two DSL locations the installed plugin version reads.
        python {
            version "3.8"
            pip {
                // A requirement specifier, with or without a version number:
                install "scipy"
                install "soundfile"
//                install "requests==2.24.0"
//                install "setuptools==39.1.0"
//                  install "numpy==1.20.3"
//                install "numba==0.51.0"
                install "librosa==0.9.2"
                install "resampy==0.3.1"
//                install "cantools==39.3.0"
//                install "bitstruct==8.17.0"
//                install "msgpack==1.0.6"

                // An sdist or wheel filename, relative to the project directory:
                // install "MyPackage-1.2.3-py2.py3-none-any.whl"

                // A directory containing a setup.py, relative to the project
                // directory (must contain at least one slash):
                //   install "./MyPackage"

                // "-r" followed by a requirements filename, relative to the
                // project directory:
                //    install "-r", "requirements.txt"
            }
        }
    }

    // One flavor per Python version; both belong to the "pyVersion" dimension.
    flavorDimensions += "pyVersion"
    productFlavors {
        create("py38") { dimension = "pyVersion" }
        create("py39") { dimension = "pyVersion" }
    }

    buildTypes {
        release {
            // Code shrinking is off; the ProGuard files only take effect once it is enabled.
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    // Java 8 source and bytecode level.
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}

// Top-level Chaquopy configuration.
chaquopy {
    defaultConfig {
        pyc {
            // presumably turns off .pyc compilation of the app's own Python sources - confirm
            // against the Chaquopy documentation for the plugin version in use
            src = false
        }
    }
    sourceSets { }
    productFlavors {
        // NOTE(review): only the "py38" flavor is given a Python version here; the "py39"
        // flavor declared in android.productFlavors has no entry - confirm that is intended.
        getByName("py38") { version = "3.8" }
    }
}

// Library dependencies of the app module. Entries starting with "libs." come from the
// version catalog; the remaining ones are declared with explicit coordinates.
dependencies {

    // AndroidX / Material UI libraries.
    implementation libs.appcompat
    implementation libs.material
    implementation libs.activity
    implementation libs.constraintlayout
    // Unit and instrumentation test libraries.
    testImplementation libs.junit
    androidTestImplementation libs.ext.junit
    androidTestImplementation libs.espresso.core
    // ONNX Runtime, used by the app's ai.onnxruntime imports.
    implementation libs.onnxruntime.android
    implementation libs.onnxruntime.extensions.android

    // Room runtime plus its annotation processor, both pinned to the same version.
    def room_version = "2.6.1"
    implementation "androidx.room:room-runtime:$room_version"
    // FFmpegKit, used by the app's com.arthenica.ffmpegkit imports.
    implementation 'com.arthenica:ffmpeg-kit-full:6.0-2'
    annotationProcessor "androidx.room:room-compiler:$room_version"
}

-ai
--InstrumentClassifier.java

package com.sec.android.app.genremusiconnx.ai;

import static java.lang.Double.MAX_VALUE;

import android.content.Context;
import android.util.Log;

import com.arthenica.ffmpegkit.FFmpegKit;
import com.chaquo.python.PyObject;
import com.chaquo.python.Python;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class InstrumentClassifier {

    public static final String[] labels = new String[]{"dan_bau","dan_tranh","guitar","violin"};
    int [] inputShapes = new int[]{44,13};
    private OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment();
    private OrtSession ortSession;
    private Context context;

    public InstrumentClassifier(Context context) {
        OrtSession.SessionOptions sessionOption = new OrtSession.SessionOptions();
        try {
            this.context = context;
            ortSession = ortEnvironment.createSession(readModel(),sessionOption);
        } catch (OrtException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    private byte[] readModel() throws IOException {
        InputStream inputStream =  context.getAssets().open("converted_model.onnx");
        int size = inputStream.available();
        byte [] buffer = new byte[size];
        inputStream.read(buffer);
        inputStream.close();
        return buffer;
    }

    public void release() {
        try {
            ortSession.close();
            ortEnvironment.close();
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    public int inference(String path){
        String outputTemp = path.substring(0, path.indexOf(".wav"))+"_temp.wav";
        FFmpegKit.execute("-i "+path+" -ar 22050 -ac 1 "+outputTemp);

        Log.i("bacnv", "inference: "+path);
        Python python = Python.getInstance();
        PyObject pyObject = python.getModule("script");
        PyObject result = pyObject.callAttr("main",outputTemp);

        List<PyObject> mfccList = result.asList();

        double [] flattenData = flattenData(mfccList);

        StringBuilder stringBuilder = new StringBuilder();
        for (int i = 0; i < flattenData.length; i++) {
            stringBuilder.append(flattenData[i]+" ");
        }
        Log.i("bacnv", stringBuilder.toString());

        try {
            OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, DoubleBuffer.wrap(flattenData),new long[]{1,44,13,1});
            Map<String,OnnxTensor> inputModel = new HashMap<>();
            inputModel.put("input",onnxTensor);

            OrtSession.Result output = ortSession.run(inputModel);
            OnnxTensor outputOnnxTensor = (OnnxTensor) output.get(0);
            float[] outputArray = outputOnnxTensor.getFloatBuffer().array();
            Log.i("bacnv", "inference: "+outputArray.length);

            int maxScoreIndex = argmax(outputArray);
            Log.i("bacnv", "Result predict - index: "+maxScoreIndex);
            Log.i("bacnv", "Result predict: "+labels[maxScoreIndex]);
//            new File(outputTemp).delete();
            return maxScoreIndex;
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    private int argmax(float[] array) {
        int maxIdx = 0;
        double maxVal = -MAX_VALUE;
        for (int j = 0; j < array.length; j++) {
            if (array[j] > maxVal) {
                maxVal = array[j];
                maxIdx = j;
            }
        }
        return maxIdx;
    }

    private double[] flattenData(List<PyObject> mfccList) {
        int index = 0;
        double[] res = new double[inputShapes[0]*inputShapes[1]];
        for (int i = 0; i < inputShapes[0] ; i++) {
            List<PyObject> mfccListSub = mfccList.get(i).asList();
            for (int j = 0; j < inputShapes[1]; j++) {
                res[index] = mfccListSub.get(j).toDouble();
                index += 1;
            }
        }
        return res;
    }
}

-database
-metadata
--RecorderMetadata.java

package com.sec.android.app.genremusiconnx.metadata;

import android.content.Context;

import com.sec.android.app.genremusiconnx.database.AppDatabase;
import com.sec.android.app.genremusiconnx.database.RecordItem;
import com.sec.android.app.genremusiconnx.database.RecordItemDAO;

/**
 * Holds the details of one recording (name, file path, category and start/end timestamps) and
 * exposes a lazily created shared instance through {@link #getInstance()}.
 */
public class RecorderMetadata {

    private static RecorderMetadata mInstance;

    private String name;
    private String path;
    private int category;
    private long startTime;
    private long endTime;

    public RecorderMetadata() {
    }

    /** Returns the shared instance, creating it on first use. */
    public static RecorderMetadata getInstance() {
        if (mInstance != null) {
            return mInstance;
        }
        mInstance = new RecorderMetadata();
        return mInstance;
    }

    /** Drops the shared instance; the next {@link #getInstance()} call creates a fresh one. */
    public void release() {
        mInstance = null;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public int getCategory() {
        return category;
    }

    public void setCategory(int category) {
        this.category = category;
    }

    public long getStartTime() {
        return startTime;
    }

    public void setStartTime(long startTime) {
        this.startTime = startTime;
    }

    public long getEndTime() {
        return endTime;
    }

    public void setEndTime(long endTime) {
        this.endTime = endTime;
    }

    /** Returns the end timestamp minus the start timestamp. */
    public long getDuration() {
        final long elapsed = endTime - startTime;
        return elapsed;
    }

    /**
     * Inserts the shared instance's values as a new {@link RecordItem} and then drops the shared
     * instance.
     *
     * @param context context used to obtain the database
     */
    public void saveToDatabase(Context context) {
        final RecordItemDAO dao = AppDatabase.getInstance(context).recordItemDAO();
        // The values are always read from the shared instance, not from "this".
        final RecorderMetadata current = RecorderMetadata.getInstance();
        final RecordItem item = new RecordItem(
                current.getName(),
                current.path,
                current.getCategory(),
                current.getEndTime(),
                current.getDuration());
        dao.insert(item);
        current.release();
    }
}
-util
-view
package com.sec.android.app.genremusiconnx.view;

import static java.lang.Double.MAX_VALUE;



--RecordFragment.java

package com.sec.android.app.genremusiconnx.view;

import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;

import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import com.sec.android.app.genremusiconnx.util.BackgroundTask;

import java.text.SimpleDateFormat;

public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener {

    private TextView timeRecord;
    private Button recordControl;
    private Boolean isRecording;
    private ImageView buttonOpenList;
    private FrameLayout progressLayout;

    public RecordFragment() {

    }

    @Nullable
    @Override
    public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
        View view = inflater.inflate(R.layout.record_fragment, container, false);
        initView(view);
//        InstrumentClassifier instrumentClassifier = new InstrumentClassifier(getContext());
//        instrumentClassifier.inference("/storage/emulated/0/Ringtones/piano_4.wav");
        return view;
    }

    private void initView(View view) {
        timeRecord = view.findViewById(R.id.time_record);
        recordControl = view.findViewById(R.id.control_record);
        buttonOpenList = view.findViewById(R.id.open_list);
        isRecording = false;
        progressLayout = view.findViewById(R.id.progress_layout);

        recordControl.setOnClickListener(v -> handleControlRecord());
        buttonOpenList.setOnClickListener(v ->{
            if (!isRecording) {
                FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT);
            }
        });
    }

    private void handleControlRecord() {
        if (isRecording) {
            stopRecord();
            isRecording = false;
        } else {
            startRecord();
            isRecording = true;
        }
    }

    private void startRecord() {
        recordControl.setText("Stop record");
        Recorder.getInstance().startRecord(getContext(), this);
        RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis());
        Thread timer = new Thread(() -> {
            int tick = 0;
            while (isRecording) {
                int finalTick = tick;
                new Handler(Looper.getMainLooper()).post(() -> updateUi(finalTick));
                tick += 1000;
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }
        });
        timer.start();
    }

    private void updateUi(int tick) {
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("mm:ss");
        timeRecord.setText(simpleDateFormat.format(tick));
    }

    private void stopRecord() {
        isRecording = false;
        recordControl.setText("Start record");
        Recorder.getInstance().stopRecord();
        RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis());
        InstrumentClassifier modelClassifier = new InstrumentClassifier(getContext());

        final int[] classifiedCategoryId = {-1};
        BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() {
            @Override
            public void doInBackground() {
                getActivity().runOnUiThread(new Runnable() {
                    @Override
                    public void run() {
                        progressLayout.setVisibility(View.VISIBLE);
                    }
                });
                classifiedCategoryId[0] = modelClassifier.inference(RecorderMetadata.getInstance().getPath());
            }

            @Override
            public void onDone() {
                getActivity().runOnUiThread(() -> {
                    progressLayout.setVisibility(View.GONE);
                    RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]);
                    RecorderMetadata.getInstance().saveToDatabase(getContext());
                    new Handler(Looper.getMainLooper()).postDelayed(()
                            -> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500);
                });
            }
        });
        backgroundTask.start();
    }

    @Override
    public void onInfo(MediaRecorder mr, int what, int extra) {
        Log.i("bacnv", "onInfo: " + extra);
        if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) {
            stopRecord();
        }
    }
}
import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.util.Log;
import android.widget.Button;
import android.widget.TextView;
import android.widget.Toast;
import android.window.OnBackInvokedDispatcher;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import androidx.fragment.app.FragmentManager;

import com.chaquo.python.PyObject;
import com.chaquo.python.Python;
import com.chaquo.python.android.AndroidPlatform;
import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.database.AppDatabase;
import com.sec.android.app.genremusiconnx.database.Category;
import com.sec.android.app.genremusiconnx.database.CategoryDAO;
import com.sec.android.app.genremusiconnx.database.RecordItem;
import com.sec.android.app.genremusiconnx.database.RecordItemDAO;

import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

/**
 * Entry activity: seeds the category table, starts the Python runtime, asks for the runtime
 * permissions the app needs and then shows the record screen.
 */
public class MainActivity extends AppCompatActivity {

    public static final int REQUEST_CODE = 1;
    private FragmentController fragmentController;


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        fragmentController = FragmentController.getInstance(R.id.layout_container);
        fragmentController.setActivity(this);

        initDatabase();

        if (!Python.isStarted()){
            Python.start(new AndroidPlatform(this));
        }

        checkPermissions(new String[] {Manifest.permission.READ_MEDIA_AUDIO,
                Manifest.permission.RECORD_AUDIO, Manifest.permission.WRITE_EXTERNAL_STORAGE},
                REQUEST_CODE);
    }

    /** Inserts the four instrument categories when the category table is empty. */
    private void initDatabase() {
        CategoryDAO categoryDAO = AppDatabase.getInstance(this).categoryDAO();
        if (categoryDAO.getAllCategory().isEmpty()) {
            categoryDAO.insert(
                    new Category(0,"Dan bau"),
                    new Category(1,"Dan tranh"),
                    new Category(2,"Guitar"),
                    new Category(3,"Violin")
            );
        }
    }

    /**
     * Requests every permission from {@code permissions} that applies to this Android version
     * and is not granted yet; shows the record screen right away when nothing is missing.
     *
     * @param permissions candidate permissions
     * @param requestCode code passed back to {@link #onRequestPermissionsResult}
     */
    private void checkPermissions(String[] permissions, int requestCode) {
        // ArrayList is fully qualified because the file does not import it.
        List<String> missing = new java.util.ArrayList<>();
        for (String permission : permissions) {
            if (isRequestable(permission)
                    && ContextCompat.checkSelfPermission(this, permission) != PackageManager.PERMISSION_GRANTED) {
                missing.add(permission);
            }
        }
        if (missing.isEmpty()) {
            showRecordFragment();
            return;
        }
        ActivityCompat.requestPermissions(this, missing.toArray(new String[0]), requestCode);
    }

    /**
     * Tells whether a permission can be granted on the running Android version.
     *
     * <p>READ_MEDIA_AUDIO only exists from Android 13 (API 33) on, and from that version on
     * WRITE_EXTERNAL_STORAGE is no longer granted. Requesting a permission that can never be
     * granted would make the check fail on every launch.
     */
    private static boolean isRequestable(String permission) {
        boolean granularMedia =
                android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.TIRAMISU;
        if (Manifest.permission.READ_MEDIA_AUDIO.equals(permission)) {
            return granularMedia;
        }
        if (Manifest.permission.WRITE_EXTERNAL_STORAGE.equals(permission)) {
            return !granularMedia;
        }
        return true;
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode != REQUEST_CODE) {
            return;
        }
        // An empty result means the request was cancelled. Every requested permission has to be
        // granted, not just the first one, before the recorder screen may be shown.
        boolean allGranted = grantResults.length > 0;
        for (int result : grantResults) {
            if (result != PackageManager.PERMISSION_GRANTED) {
                allGranted = false;
                break;
            }
        }
        if (allGranted) {
            showRecordFragment();
        } else {
            Toast.makeText(this,"Permission denied!",Toast.LENGTH_SHORT).show();
        }
    }

    private void showRecordFragment() {
        fragmentController.navigateTo(FragmentController.RECORD_FRAGMENT);
    }

    @Override
    public void onBackPressed() {
        // Let the fragment controller handle "back" first; fall through to the default otherwise.
        if(!fragmentController.navigateBack()) {
            super.onBackPressed();
        }
    }
}

--MainActivity.java



...
--Recorder.java

package com.sec.android.app.genremusiconnx.view;

import android.content.Context;
import android.media.MediaFormat;
import android.media.MediaRecorder;
import android.os.Environment;
import android.util.Log;

import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;

import java.io.File;
import java.io.IOException;

/**
 * Singleton wrapper around {@link MediaRecorder} that records microphone audio into a new file
 * under "Recordings/GenreMusic" on external storage.
 */
public class Recorder {

    private static final String TAG = "bacnv";

    private static Recorder mInstance;
    private static MediaRecorder mMediaRecorder;

    private Recorder() {

    }

    /** Returns the shared instance, creating it on first use. */
    public static Recorder getInstance() {
        if (mInstance == null) {
            mInstance = new Recorder();
        }
        return mInstance;
    }

    /**
     * Starts a new recording.
     *
     * @param context  context used to create the recorder
     * @param listener receives recorder events such as "max duration reached"
     * @throws RuntimeException if the recorder cannot be prepared
     */
    public void startRecord(Context context, MediaRecorder.OnInfoListener listener) {
        // Finish a recording that was never stopped so its native recorder is not leaked.
        stopRecord();
        MediaRecorder recorder = getMediaRecorder(context);
        recorder.setOnInfoListener(listener);
        mMediaRecorder = recorder;
        recorder.start();
    }

    /**
     * Stops the current recording, if any, and releases the recorder. Calling this again without
     * a new {@link #startRecord} is a no-op.
     */
    public void stopRecord() {
        MediaRecorder recorder = mMediaRecorder;
        if (recorder == null) {
            return;
        }
        // Cleared first so a second call never touches a released recorder.
        mMediaRecorder = null;
        try {
            recorder.stop();
        } finally {
            // Release even when stop() throws, e.g. because no audio data was captured yet.
            recorder.release();
        }
    }

    /**
     * Creates and prepares a recorder for a fresh output file and stores that file's path in
     * {@link RecorderMetadata}.
     */
    private MediaRecorder getMediaRecorder(Context context) {
        MediaRecorder recorder = new MediaRecorder(context);
        String outputFile = getFileName(context);
        RecorderMetadata.getInstance().setPath(outputFile);
        try {
            // Configured in the order the MediaRecorder state machine documents:
            // source, then output format, then encoder and its parameters.
            recorder.setAudioSource(MediaRecorder.AudioSource.MIC);
            recorder.setOutputFormat(MediaRecorder.OutputFormat.THREE_GPP);
            recorder.setAudioEncoder(MediaRecorder.AudioEncoder.AAC);
            recorder.setAudioChannels(1);
            recorder.setAudioSamplingRate(22050);
            recorder.setMaxDuration(10000);
            recorder.setOutputFile(outputFile);
            recorder.prepare();
        } catch (IOException e) {
            Log.i(TAG, "getMediaRecorder: "+e.getMessage());
            recorder.release();
            throw new RuntimeException(e);
        } catch (RuntimeException e) {
            // Do not leak the native recorder when configuration fails.
            recorder.release();
            throw e;
        }

        return recorder;
    }

    /**
     * Makes sure the output folder exists and returns the path of a new, timestamped file in it.
     * The base name is also stored in {@link RecorderMetadata}.
     *
     * <p>The file is named ".wav" although the recorder writes AAC in a 3GPP container.
     */
    private String getFileName(Context context) {
        File file = new File(Environment.getExternalStorageDirectory(),"Recordings/GenreMusic");
        if (!file.exists()) {
            Log.i(TAG, "startRecord: "+file.getPath());
            if (!file.mkdirs()) {
                Log.i(TAG, "can not create folder: ");
            }
        } else {
            Log.i(TAG, "startRecord: Folder exist!!");
        }
        String fileName = "recording_"+System.currentTimeMillis();
        RecorderMetadata.getInstance().setName(fileName);
        return file.getPath()+"/"+fileName+".wav";
    }
}


--RecordFragment.java

package com.sec.android.app.genremusiconnx.view;

import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;

import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import com.sec.android.app.genremusiconnx.util.BackgroundTask;

import java.text.SimpleDateFormat;

public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener {

    private TextView timeRecord;
    private Button recordControl;
    private Boolean isRecording;
    private ImageView buttonOpenList;
    private FrameLayout progressLayout;

    public RecordFragment() {

    }

    @Nullable
    @Override
    public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
        View view = inflater.inflate(R.layout.record_fragment, container, false);
        initView(view);
//        InstrumentClassifier instrumentClassifier = new InstrumentClassifier(getContext());
//        instrumentClassifier.inference("/storage/emulated/0/Ringtones/piano_4.wav");
        return view;
    }

    private void initView(View view) {
        timeRecord = view.findViewById(R.id.time_record);
        recordControl = view.findViewById(R.id.control_record);
        buttonOpenList = view.findViewById(R.id.open_list);
        isRecording = false;
        progressLayout = view.findViewById(R.id.progress_layout);

        recordControl.setOnClickListener(v -> handleControlRecord());
        buttonOpenList.setOnClickListener(v ->{
            if (!isRecording) {
                FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT);
            }
        });
    }

    private void handleControlRecord() {
        if (isRecording) {
            stopRecord();
            isRecording = false;
        } else {
            startRecord();
            isRecording = true;
        }
    }

    private void startRecord() {
        recordControl.setText("Stop record");
        Recorder.getInstance().startRecord(getContext(), this);
        RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis());
        Thread timer = new Thread(() -> {
            int tick = 0;
            while (isRecording) {
                int finalTick = tick;
                new Handler(Looper.getMainLooper()).post(() -> updateUi(finalTick));
                tick += 1000;
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }
        });
        timer.start();
    }

    private void updateUi(int tick) {
        SimpleDateFormat simpleDateFormat = new SimpleDateFormat("mm:ss");
        timeRecord.setText(simpleDateFormat.format(tick));
    }

    private void stopRecord() {
        isRecording = false;
        recordControl.setText("Start record");
        Recorder.getInstance().stopRecord();
        RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis());
        InstrumentClassifier modelClassifier = new InstrumentClassifier(getContext());

        final int[] classifiedCategoryId = {-1};
        BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() {
            @Override
            public void doInBackground() {
                getActivity().runOnUiThread(new Runnable() {
                    @Override
                    public void run() {
                        progressLayout.setVisibility(View.VISIBLE);
                    }
                });
                classifiedCategoryId[0] = modelClassifier.inference(RecorderMetadata.getInstance().getPath());
            }

            @Override
            public void onDone() {
                getActivity().runOnUiThread(() -> {
                    progressLayout.setVisibility(View.GONE);
                    RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]);
                    RecorderMetadata.getInstance().saveToDatabase(getContext());
                    new Handler(Looper.getMainLooper()).postDelayed(()
                            -> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500);
                });
            }
        });
        backgroundTask.start();
    }

    @Override
    public void onInfo(MediaRecorder mr, int what, int extra) {
        Log.i("bacnv", "onInfo: " + extra);
        if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) {
            stopRecord();
        }
    }
}


Editor is loading...
Leave a Comment