Untitled
unknown
plain_text
a year ago
26 kB
8
Indexable
#onnxleduymanh
//build.gradle
// Gradle plugins applied to this module: the Android application plugin
// (resolved through the version catalog) and the Chaquopy Python plugin.
plugins {
alias(libs.plugins.android.application)
id("com.chaquo.python")
}
// Android module configuration: SDK levels, application id, native ABIs,
// Python packages installed at build time, product flavors and build types.
android {
namespace 'com.sec.android.app.genremusiconnx'
compileSdk 34
defaultConfig {
applicationId "com.sec.android.app.genremusiconnx"
minSdk 32
targetSdk 33
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// NOTE(review): this sourceSets block sits inside defaultConfig rather than
// directly under android {} — confirm that this placement is intended.
sourceSets {
main{
python{
srcDirs = ['src/main/python']
}
}
}
// Native ABIs to build for.
ndk {
// On Apple silicon, you can omit x86_64.
abiFilters "arm64-v8a", "x86_64"
}
// Python version and pip requirements.
// NOTE(review): Python is also configured in the top-level chaquopy {} block of
// this script — confirm both configuration styles are meant to coexist.
python {
version "3.8"
pip {
// A requirement specifier, with or without a version number:
install "scipy"
install "soundfile"
// install "requests==2.24.0"
// install "setuptools==39.1.0"
// install "numpy==1.20.3"
// install "numba==0.51.0"
install "librosa==0.9.2"
install "resampy==0.3.1"
// install "cantools==39.3.0"
// install "bitstruct==8.17.0"
// install "msgpack==1.0.6"
// An sdist or wheel filename, relative to the project directory:
// install "MyPackage-1.2.3-py2.py3-none-any.whl"
// A directory containing a setup.py, relative to the project
// directory (must contain at least one slash):
// install "./MyPackage"
// "-r"` followed by a requirements filename, relative to the
// project directory:
// install "-r", "requirements.txt"
}
}
}
// One flavor dimension with two flavors, py38 and py39.
flavorDimensions += "pyVersion"
productFlavors {
create("py38") { dimension = "pyVersion" }
create("py39") { dimension = "pyVersion" }
}
buildTypes {
release {
// Code shrinking is disabled for release builds.
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
// Java 8 source and bytecode level.
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
// Chaquopy (embedded Python) configuration.
chaquopy {
defaultConfig {
pyc {
// Sets the pyc "src" option to false.
// NOTE(review): presumably this disables .pyc compilation of the app's own
// Python sources — confirm against the Chaquopy documentation.
src = false
}
}
sourceSets { }
productFlavors {
// NOTE(review): only py38 gets an explicit version here; the py39 flavor declared
// under android.productFlavors has no matching entry — confirm that is intended.
getByName("py38") { version = "3.8" }
}
}
// Library dependencies: AndroidX UI libraries, test libraries, ONNX Runtime,
// Room and FFmpegKit.
dependencies {
implementation libs.appcompat
implementation libs.material
implementation libs.activity
implementation libs.constraintlayout
testImplementation libs.junit
androidTestImplementation libs.ext.junit
androidTestImplementation libs.espresso.core
// ONNX Runtime and its extensions package.
implementation libs.onnxruntime.android
implementation libs.onnxruntime.extensions.android
// Room runtime and annotation processor share one version.
def room_version = "2.6.1"
implementation "androidx.room:room-runtime:$room_version"
// FFmpegKit (imported by InstrumentClassifier as com.arthenica.ffmpegkit.FFmpegKit).
implementation 'com.arthenica:ffmpeg-kit-full:6.0-2'
annotationProcessor "androidx.room:room-compiler:$room_version"
}
-ai
--InstrumentClassifier.java
package com.sec.android.app.genremusiconnx.ai;
import static java.lang.Double.MAX_VALUE;
import android.content.Context;
import android.util.Log;
import com.arthenica.ffmpegkit.FFmpegKit;
import com.chaquo.python.PyObject;
import com.chaquo.python.Python;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
/**
 * Classifies an audio clip as one of the instruments listed in {@link #labels}.
 *
 * <p>Pipeline used by {@link #inference(String)}: FFmpeg re-encodes the clip to 22050 Hz mono,
 * the Python module {@code script} (run through Chaquopy) returns a nested list that is read as
 * an {@code inputShapes[0] x inputShapes[1]} matrix, and the ONNX model bundled as the asset
 * {@code converted_model.onnx} scores it.
 *
 * <p>The instance owns a native ONNX session; call {@link #release()} when it is no longer needed.
 */
public class InstrumentClassifier {
    private static final String TAG = "bacnv";
    private static final String MODEL_ASSET = "converted_model.onnx";
    private static final String WAV_SUFFIX = ".wav";

    /** Class names, indexed by the value returned from {@link #inference(String)}. */
    public static final String[] labels = new String[]{"dan_bau","dan_tranh","guitar","violin"};
    /** Rows and columns of the feature matrix fed to the model. */
    int [] inputShapes = new int[]{44,13};
    private OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment();
    private OrtSession ortSession;
    private Context context;

    /**
     * Loads the bundled model and creates the ONNX session.
     *
     * @param context used to open the model asset
     * @throws RuntimeException wrapping the {@link OrtException} or {@link IOException} raised
     *     while reading the asset or creating the session
     */
    public InstrumentClassifier(Context context) {
        this.context = context;
        // The options object is only needed while the session is being created; closing it
        // afterwards frees its native handle.
        try (OrtSession.SessionOptions sessionOption = new OrtSession.SessionOptions()) {
            ortSession = ortEnvironment.createSession(readModel(), sessionOption);
        } catch (OrtException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads the whole model asset into memory.
     *
     * <p>A single {@code read(byte[])} call may return fewer bytes than requested, so the stream
     * is drained in a loop; try-with-resources closes it even when reading fails.
     */
    private byte[] readModel() throws IOException {
        try (InputStream inputStream = context.getAssets().open(MODEL_ASSET)) {
            java.io.ByteArrayOutputStream collected =
                    new java.io.ByteArrayOutputStream(Math.max(inputStream.available(), 8192));
            byte[] chunk = new byte[8192];
            int read;
            while ((read = inputStream.read(chunk)) != -1) {
                collected.write(chunk, 0, read);
            }
            return collected.toByteArray();
        }
    }

    /**
     * Closes the ONNX session and environment. Safe to call more than once.
     *
     * @throws RuntimeException wrapping any {@link OrtException} raised while closing
     */
    public void release() {
        try {
            if (ortSession != null) {
                ortSession.close();
                ortSession = null;
            }
            ortEnvironment.close();
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the full pipeline on the audio file at {@code path}.
     *
     * @param path absolute path of the recorded clip
     * @return index into {@link #labels} of the highest-scoring class
     * @throws IllegalArgumentException if {@code path} is null or the extracted features are
     *     smaller than the model input
     * @throws IllegalStateException if FFmpeg did not produce the converted file
     * @throws RuntimeException wrapping any {@link OrtException} raised by the model
     */
    public int inference(String path){
        if (path == null) {
            throw new IllegalArgumentException("path must not be null");
        }
        String outputTemp = tempPathFor(path);
        File tempFile = new File(outputTemp);
        try {
            // Argument array instead of one concatenated string so that paths containing spaces
            // survive; "-y" overwrites a stale temp file left by an earlier run.
            FFmpegKit.executeWithArguments(new String[]{
                    "-y", "-i", path, "-ar", "22050", "-ac", "1", outputTemp});
            Log.i(TAG, "inference: "+path);
            if (!tempFile.isFile()) {
                throw new IllegalStateException("FFmpeg did not produce " + outputTemp);
            }
            Python python = Python.getInstance();
            PyObject pyObject = python.getModule("script");
            PyObject result = pyObject.callAttr("main",outputTemp);
            double [] flattenData = flattenData(result.asList());
            StringBuilder stringBuilder = new StringBuilder();
            for (double value : flattenData) {
                stringBuilder.append(value).append(' ');
            }
            Log.i(TAG, stringBuilder.toString());
            return classify(flattenData);
        } finally {
            // The converted file is only an intermediate artifact.
            if (tempFile.exists() && !tempFile.delete()) {
                Log.w(TAG, "inference: could not delete " + outputTemp);
            }
        }
    }

    /** Derives the temp-file path: the input path without a trailing ".wav", plus "_temp.wav". */
    private static String tempPathFor(String path) {
        String base = path.endsWith(WAV_SUFFIX)
                ? path.substring(0, path.length() - WAV_SUFFIX.length())
                : path;
        return base + "_temp.wav";
    }

    /** Feeds the flattened features to the model and returns the arg-max class index. */
    private int classify(double[] flattenData) {
        long[] tensorShape = new long[]{1, inputShapes[0], inputShapes[1], 1};
        // NOTE(review): the input tensor is created as double while the output is read as float;
        // confirm the model's "input" really expects float64.
        // Tensor and result hold native memory, so both are closed once the scores are copied out.
        try (OnnxTensor onnxTensor = OnnxTensor.createTensor(
                ortEnvironment, DoubleBuffer.wrap(flattenData), tensorShape)) {
            Map<String,OnnxTensor> inputModel = new HashMap<>();
            inputModel.put("input",onnxTensor);
            try (OrtSession.Result output = ortSession.run(inputModel)) {
                OnnxTensor outputOnnxTensor = (OnnxTensor) output.get(0);
                float[] outputArray = outputOnnxTensor.getFloatBuffer().array();
                Log.i(TAG, "inference: "+outputArray.length);
                int maxScoreIndex = argmax(outputArray);
                Log.i(TAG, "Result predict - index: "+maxScoreIndex);
                if (maxScoreIndex < labels.length) {
                    Log.i(TAG, "Result predict: "+labels[maxScoreIndex]);
                }
                return maxScoreIndex;
            }
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the index of the largest element, or 0 for an empty array. */
    private int argmax(float[] array) {
        int maxIdx = 0;
        double maxVal = -MAX_VALUE;
        for (int j = 0; j < array.length; j++) {
            if (array[j] > maxVal) {
                maxVal = array[j];
                maxIdx = j;
            }
        }
        return maxIdx;
    }

    /**
     * Copies the first {@code inputShapes[0]} rows and {@code inputShapes[1]} columns of the
     * nested Python list into a row-major array.
     *
     * @throws IllegalArgumentException if the list has fewer rows or columns than required
     */
    private double[] flattenData(List<PyObject> mfccList) {
        int rows = inputShapes[0];
        int cols = inputShapes[1];
        if (mfccList.size() < rows) {
            throw new IllegalArgumentException(
                    "Expected at least " + rows + " feature rows but got " + mfccList.size());
        }
        double[] res = new double[rows * cols];
        int index = 0;
        for (int i = 0; i < rows; i++) {
            List<PyObject> mfccListSub = mfccList.get(i).asList();
            if (mfccListSub.size() < cols) {
                throw new IllegalArgumentException(
                        "Expected at least " + cols + " values in row " + i
                                + " but got " + mfccListSub.size());
            }
            for (int j = 0; j < cols; j++) {
                res[index++] = mfccListSub.get(j).toDouble();
            }
        }
        return res;
    }
}
-database
-metadata
--RecorderMetadata.java
package com.sec.android.app.genremusiconnx.metadata;
import android.content.Context;
import com.sec.android.app.genremusiconnx.database.AppDatabase;
import com.sec.android.app.genremusiconnx.database.RecordItem;
import com.sec.android.app.genremusiconnx.database.RecordItemDAO;
/**
 * Holder for the attributes of the recording currently in progress (name, file path,
 * start/end timestamps and classified category), reachable through {@link #getInstance()}.
 */
public class RecorderMetadata {
    private static RecorderMetadata mInstance;

    private String name;
    private String path;
    private long startTime;
    private long endTime;
    private int category;

    public RecorderMetadata() {
    }

    /** Returns the shared instance, creating it on first use. */
    public static RecorderMetadata getInstance() {
        if (mInstance != null) {
            return mInstance;
        }
        mInstance = new RecorderMetadata();
        return mInstance;
    }

    /** Drops the shared instance; the next {@link #getInstance()} call creates a fresh one. */
    public void release() {
        mInstance = null;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public long getStartTime() {
        return startTime;
    }

    public void setStartTime(long startTime) {
        this.startTime = startTime;
    }

    public long getEndTime() {
        return endTime;
    }

    public void setEndTime(long endTime) {
        this.endTime = endTime;
    }

    /** Returns {@code endTime - startTime}. */
    public long getDuration() {
        return endTime - startTime;
    }

    public int getCategory() {
        return category;
    }

    public void setCategory(int category) {
        this.category = category;
    }

    /**
     * Inserts the shared instance's values into the database as a {@link RecordItem} and then
     * releases the shared instance.
     *
     * @param context used to obtain the {@link AppDatabase}
     */
    public void saveToDatabase(Context context) {
        RecordItemDAO dao = AppDatabase.getInstance(context).recordItemDAO();
        // Values are read from the shared instance, not from "this".
        RecorderMetadata shared = RecorderMetadata.getInstance();
        RecordItem item = new RecordItem(
                shared.getName(),
                shared.path,
                shared.getCategory(),
                shared.getEndTime(),
                shared.getDuration());
        dao.insert(item);
        shared.release();
    }
}
-util
-view
package com.sec.android.app.genremusiconnx.view;
import static java.lang.Double.MAX_VALUE;
--RecordFragment.java
package com.sec.android.app.genremusiconnx.view;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.TextView;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;
import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import com.sec.android.app.genremusiconnx.util.BackgroundTask;
import java.text.SimpleDateFormat;
public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener {
private TextView timeRecord;
private Button recordControl;
private Boolean isRecording;
private ImageView buttonOpenList;
private FrameLayout progressLayout;
public RecordFragment() {
}
@Nullable
@Override
public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
View view = inflater.inflate(R.layout.record_fragment, container, false);
initView(view);
// InstrumentClassifier instrumentClassifier = new InstrumentClassifier(getContext());
// instrumentClassifier.inference("/storage/emulated/0/Ringtones/piano_4.wav");
return view;
}
private void initView(View view) {
timeRecord = view.findViewById(R.id.time_record);
recordControl = view.findViewById(R.id.control_record);
buttonOpenList = view.findViewById(R.id.open_list);
isRecording = false;
progressLayout = view.findViewById(R.id.progress_layout);
recordControl.setOnClickListener(v -> handleControlRecord());
buttonOpenList.setOnClickListener(v ->{
if (!isRecording) {
FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT);
}
});
}
private void handleControlRecord() {
if (isRecording) {
stopRecord();
isRecording = false;
} else {
startRecord();
isRecording = true;
}
}
private void startRecord() {
recordControl.setText("Stop record");
Recorder.getInstance().startRecord(getContext(), this);
RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis());
Thread timer = new Thread(() -> {
int tick = 0;
while (isRecording) {
int finalTick = tick;
new Handler(Looper.getMainLooper()).post(() -> updateUi(finalTick));
tick += 1000;
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});
timer.start();
}
private void updateUi(int tick) {
SimpleDateFormat simpleDateFormat = new SimpleDateFormat("mm:ss");
timeRecord.setText(simpleDateFormat.format(tick));
}
private void stopRecord() {
isRecording = false;
recordControl.setText("Start record");
Recorder.getInstance().stopRecord();
RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis());
InstrumentClassifier modelClassifier = new InstrumentClassifier(getContext());
final int[] classifiedCategoryId = {-1};
BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() {
@Override
public void doInBackground() {
getActivity().runOnUiThread(new Runnable() {
@Override
public void run() {
progressLayout.setVisibility(View.VISIBLE);
}
});
classifiedCategoryId[0] = modelClassifier.inference(RecorderMetadata.getInstance().getPath());
}
@Override
public void onDone() {
getActivity().runOnUiThread(() -> {
progressLayout.setVisibility(View.GONE);
RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]);
RecorderMetadata.getInstance().saveToDatabase(getContext());
new Handler(Looper.getMainLooper()).postDelayed(()
-> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500);
});
}
});
backgroundTask.start();
}
@Override
public void onInfo(MediaRecorder mr, int what, int extra) {
Log.i("bacnv", "onInfo: " + extra);
if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) {
stopRecord();
}
}
}
import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.util.Log;
import android.widget.Button;
import android.widget.TextView;
import android.widget.Toast;
import android.window.OnBackInvokedDispatcher;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import androidx.fragment.app.FragmentManager;
import com.chaquo.python.PyObject;
import com.chaquo.python.Python;
import com.chaquo.python.android.AndroidPlatform;
import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.database.AppDatabase;
import com.sec.android.app.genremusiconnx.database.Category;
import com.sec.android.app.genremusiconnx.database.CategoryDAO;
import com.sec.android.app.genremusiconnx.database.RecordItem;
import com.sec.android.app.genremusiconnx.database.RecordItemDAO;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
/**
 * Entry activity: sets up the fragment controller, seeds the category table, starts the
 * embedded Python runtime and shows the record screen once the runtime permissions are granted.
 */
public class MainActivity extends AppCompatActivity {
    public static final int REQUEST_CODE = 1;
    private FragmentController fragmentController;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        fragmentController = FragmentController.getInstance(R.id.layout_container);
        fragmentController.setActivity(this);
        initDatabase();
        if (!Python.isStarted()){
            Python.start(new AndroidPlatform(this));
        }
        checkPermissions(requiredPermissions(), REQUEST_CODE);
    }

    /**
     * Returns the runtime permissions to request on the running OS version.
     *
     * <p>{@code READ_MEDIA_AUDIO} only exists from API 33, so checking it on older releases can
     * never succeed. {@code WRITE_EXTERNAL_STORAGE} is only requested below API 33; requesting it
     * on newer releases can never be granted and would make the app re-prompt on every launch.
     */
    private static String[] requiredPermissions() {
        if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.TIRAMISU) {
            return new String[] {Manifest.permission.READ_MEDIA_AUDIO,
                    Manifest.permission.RECORD_AUDIO};
        }
        return new String[] {Manifest.permission.RECORD_AUDIO,
                Manifest.permission.WRITE_EXTERNAL_STORAGE};
    }

    /** Inserts the four instrument categories when the category table is empty. */
    private void initDatabase() {
        // NOTE(review): these DAO calls run on the main thread — confirm the database is
        // configured to allow that.
        CategoryDAO categoryDAO = AppDatabase.getInstance(this).categoryDAO();
        if (categoryDAO.getAllCategory().isEmpty()) {
            categoryDAO.insert(
                    new Category(0,"Dan bau"),
                    new Category(1,"Dan tranh"),
                    new Category(2,"Guitar"),
                    new Category(3,"Violin")
            );
        }
    }

    /**
     * Requests {@code permissions} if any of them is missing; otherwise shows the record screen.
     */
    private void checkPermissions(String[] permissions, int requestCode) {
        for (String permission : permissions) {
            if (ContextCompat.checkSelfPermission(MainActivity.this, permission) != PackageManager.PERMISSION_GRANTED) {
                ActivityCompat.requestPermissions(MainActivity.this, permissions, requestCode);
                return;
            }
        }
        showRecordFragment();
    }

    /**
     * Shows the record screen only when every requested permission was granted; checking just
     * the first entry would let the app continue without microphone access.
     */
    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode != REQUEST_CODE) {
            return;
        }
        if (allGranted(grantResults)) {
            showRecordFragment();
        } else {
            Toast.makeText(this,"Permission denied!",Toast.LENGTH_SHORT).show();
        }
    }

    /** Returns true when the array is non-empty and every entry is PERMISSION_GRANTED. */
    private static boolean allGranted(int[] grantResults) {
        // An empty array means the request was cancelled.
        if (grantResults.length == 0) {
            return false;
        }
        for (int result : grantResults) {
            if (result != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    private void showRecordFragment() {
        fragmentController.navigateTo(FragmentController.RECORD_FRAGMENT);
    }

    /** Lets the fragment controller handle Back first; falls back to the default behaviour. */
    @Override
    public void onBackPressed() {
        if(!fragmentController.navigateBack()) {
            super.onBackPressed();
        }
    }
}
--MainActivity.java
...
--Recorder.java
package com.sec.android.app.genremusiconnx.view;
import android.content.Context;
import android.media.MediaFormat;
import android.media.MediaRecorder;
import android.os.Environment;
import android.util.Log;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import java.io.File;
import java.io.IOException;
/**
 * Process-wide wrapper around a single {@link MediaRecorder} that records microphone audio
 * into a new file under {@code Recordings/GenreMusic} on external storage.
 */
public class Recorder {
    private static final String TAG = "bacnv";
    private static Recorder mInstance;
    private static MediaRecorder mMediaRecorder;

    private Recorder() {
    }

    /** Returns the shared instance, creating it on first use. */
    public static Recorder getInstance() {
        if (mInstance == null) {
            mInstance = new Recorder();
        }
        return mInstance;
    }

    /**
     * Prepares a new recorder and starts recording.
     *
     * @param context used to construct the {@link MediaRecorder}
     * @param listener receives recorder info events such as "max duration reached"
     * @throws RuntimeException if the recorder cannot be prepared or started
     */
    public void startRecord(Context context, MediaRecorder.OnInfoListener listener) {
        // Free a recorder left over from an earlier start that was never stopped.
        releaseQuietly();
        mMediaRecorder = getMediaRecorder(context);
        mMediaRecorder.setOnInfoListener(listener);
        try {
            mMediaRecorder.start();
        } catch (RuntimeException e) {
            releaseQuietly();
            throw e;
        }
    }

    /**
     * Stops and releases the current recorder, if any.
     *
     * <p>The recorder is released even when {@code stop()} throws, and the reference is cleared
     * so that a second call is a no-op instead of touching a released object. An exception from
     * {@code stop()} is still propagated to the caller.
     */
    public void stopRecord() {
        MediaRecorder recorder = mMediaRecorder;
        if (recorder == null) {
            return;
        }
        mMediaRecorder = null;
        try {
            recorder.stop();
        } finally {
            recorder.release();
        }
    }

    /** Releases the current recorder without stopping it and clears the reference. */
    private static void releaseQuietly() {
        if (mMediaRecorder != null) {
            mMediaRecorder.release();
            mMediaRecorder = null;
        }
    }

    /**
     * Creates and prepares a recorder: mono, 22050 Hz, microphone source, 3GPP container with
     * AAC audio, at most 10 seconds. The output path is also stored in {@link RecorderMetadata}.
     */
    private MediaRecorder getMediaRecorder(Context context) {
        mMediaRecorder = new MediaRecorder(context);
        String outputFile = getFileName(context);
        RecorderMetadata.getInstance().setPath(outputFile);
        try {
            mMediaRecorder.setAudioChannels(1);
            mMediaRecorder.setAudioSamplingRate(22050);
            mMediaRecorder.setAudioSource(MediaRecorder.AudioSource.MIC);
            mMediaRecorder.setOutputFormat(MediaRecorder.OutputFormat.THREE_GPP);
            mMediaRecorder.setAudioEncoder(MediaRecorder.AudioEncoder.AAC);
            mMediaRecorder.setMaxDuration(10000);
            mMediaRecorder.setOutputFile(outputFile);
            mMediaRecorder.prepare();
        } catch (IOException e) {
            Log.e(TAG, "getMediaRecorder: "+e.getMessage(), e);
            // Do not leak the native recorder when preparation fails.
            releaseQuietly();
            throw new RuntimeException(e);
        } catch (RuntimeException e) {
            releaseQuietly();
            throw e;
        }
        return mMediaRecorder;
    }

    /**
     * Ensures the output folder exists, stores a timestamp-based name in
     * {@link RecorderMetadata} and returns the full output path.
     */
    private String getFileName(Context context) {
        File file = new File(Environment.getExternalStorageDirectory(),"Recordings/GenreMusic");
        if (!file.exists()) {
            Log.i(TAG, "startRecord: "+file.getPath());
            if (!file.mkdirs()) {
                Log.i(TAG, "can not create folder: ");
            }
        } else {
            Log.i(TAG, "startRecord: Folder exist!!");
        }
        String fileName = "recording_"+System.currentTimeMillis();
        RecorderMetadata.getInstance().setName(fileName);
        // NOTE(review): the file is named ".wav" although the recorder writes 3GPP/AAC; the
        // suffix is kept because InstrumentClassifier derives its temp path from ".wav".
        return file.getPath()+"/"+fileName+".wav";
    }
}
--RecordFragment.java
package com.sec.android.app.genremusiconnx.view;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.TextView;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;
import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import com.sec.android.app.genremusiconnx.util.BackgroundTask;
import java.text.SimpleDateFormat;
public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener {
private TextView timeRecord;
private Button recordControl;
private Boolean isRecording;
private ImageView buttonOpenList;
private FrameLayout progressLayout;
public RecordFragment() {
}
@Nullable
@Override
public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
View view = inflater.inflate(R.layout.record_fragment, container, false);
initView(view);
// InstrumentClassifier instrumentClassifier = new InstrumentClassifier(getContext());
// instrumentClassifier.inference("/storage/emulated/0/Ringtones/piano_4.wav");
return view;
}
private void initView(View view) {
timeRecord = view.findViewById(R.id.time_record);
recordControl = view.findViewById(R.id.control_record);
buttonOpenList = view.findViewById(R.id.open_list);
isRecording = false;
progressLayout = view.findViewById(R.id.progress_layout);
recordControl.setOnClickListener(v -> handleControlRecord());
buttonOpenList.setOnClickListener(v ->{
if (!isRecording) {
FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT);
}
});
}
private void handleControlRecord() {
if (isRecording) {
stopRecord();
isRecording = false;
} else {
startRecord();
isRecording = true;
}
}
private void startRecord() {
recordControl.setText("Stop record");
Recorder.getInstance().startRecord(getContext(), this);
RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis());
Thread timer = new Thread(() -> {
int tick = 0;
while (isRecording) {
int finalTick = tick;
new Handler(Looper.getMainLooper()).post(() -> updateUi(finalTick));
tick += 1000;
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});
timer.start();
}
private void updateUi(int tick) {
SimpleDateFormat simpleDateFormat = new SimpleDateFormat("mm:ss");
timeRecord.setText(simpleDateFormat.format(tick));
}
private void stopRecord() {
isRecording = false;
recordControl.setText("Start record");
Recorder.getInstance().stopRecord();
RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis());
InstrumentClassifier modelClassifier = new InstrumentClassifier(getContext());
final int[] classifiedCategoryId = {-1};
BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() {
@Override
public void doInBackground() {
getActivity().runOnUiThread(new Runnable() {
@Override
public void run() {
progressLayout.setVisibility(View.VISIBLE);
}
});
classifiedCategoryId[0] = modelClassifier.inference(RecorderMetadata.getInstance().getPath());
}
@Override
public void onDone() {
getActivity().runOnUiThread(() -> {
progressLayout.setVisibility(View.GONE);
RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]);
RecorderMetadata.getInstance().saveToDatabase(getContext());
new Handler(Looper.getMainLooper()).postDelayed(()
-> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500);
});
}
});
backgroundTask.start();
}
@Override
public void onInfo(MediaRecorder mr, int what, int extra) {
Log.i("bacnv", "onInfo: " + extra);
if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) {
stopRecord();
}
}
}
Editor is loading...
Leave a Comment