Untitled
unknown
plain_text
5 months ago
26 kB
6
Indexable
#onnxleduymanh //build.gradle plugins { alias(libs.plugins.android.application) id("com.chaquo.python") } android { namespace 'com.sec.android.app.genremusiconnx' compileSdk 34 defaultConfig { applicationId "com.sec.android.app.genremusiconnx" minSdk 32 targetSdk 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" sourceSets { main{ python{ srcDirs = ['src/main/python'] } } } ndk { // On Apple silicon, you can omit x86_64. abiFilters "arm64-v8a", "x86_64" } python { version "3.8" pip { // A requirement specifier, with or without a version number: install "scipy" install "soundfile" // install "requests==2.24.0" // install "setuptools==39.1.0" // install "numpy==1.20.3" // install "numba==0.51.0" install "librosa==0.9.2" install "resampy==0.3.1" // install "cantools==39.3.0" // install "bitstruct==8.17.0" // install "msgpack==1.0.6" // An sdist or wheel filename, relative to the project directory: // install "MyPackage-1.2.3-py2.py3-none-any.whl" // A directory containing a setup.py, relative to the project // directory (must contain at least one slash): // install "./MyPackage" // "-r"` followed by a requirements filename, relative to the // project directory: // install "-r", "requirements.txt" } } } flavorDimensions += "pyVersion" productFlavors { create("py38") { dimension = "pyVersion" } create("py39") { dimension = "pyVersion" } } buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } } chaquopy { defaultConfig { pyc { src = false } } sourceSets { } productFlavors { getByName("py38") { version = "3.8" } } } dependencies { implementation libs.appcompat implementation libs.material implementation libs.activity implementation libs.constraintlayout testImplementation libs.junit androidTestImplementation libs.ext.junit 
androidTestImplementation libs.espresso.core implementation libs.onnxruntime.android implementation libs.onnxruntime.extensions.android def room_version = "2.6.1" implementation "androidx.room:room-runtime:$room_version" implementation 'com.arthenica:ffmpeg-kit-full:6.0-2' annotationProcessor "androidx.room:room-compiler:$room_version" } -ai --InstrumentClassifier.java package com.sec.android.app.genremusiconnx.ai; import static java.lang.Double.MAX_VALUE; import android.content.Context; import android.util.Log; import com.arthenica.ffmpegkit.FFmpegKit; import com.chaquo.python.PyObject; import com.chaquo.python.Python; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; public class InstrumentClassifier { public static final String[] labels = new String[]{"dan_bau","dan_tranh","guitar","violin"}; int [] inputShapes = new int[]{44,13}; private OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment(); private OrtSession ortSession; private Context context; public InstrumentClassifier(Context context) { OrtSession.SessionOptions sessionOption = new OrtSession.SessionOptions(); try { this.context = context; ortSession = ortEnvironment.createSession(readModel(),sessionOption); } catch (OrtException | IOException e) { throw new RuntimeException(e); } } private byte[] readModel() throws IOException { InputStream inputStream = context.getAssets().open("converted_model.onnx"); int size = inputStream.available(); byte [] buffer = new byte[size]; inputStream.read(buffer); inputStream.close(); return buffer; } public void release() { try { ortSession.close(); ortEnvironment.close(); } catch (OrtException e) { throw new RuntimeException(e); } } public int inference(String path){ 
String outputTemp = path.substring(0, path.indexOf(".wav"))+"_temp.wav"; FFmpegKit.execute("-i "+path+" -ar 22050 -ac 1 "+outputTemp); Log.i("bacnv", "inference: "+path); Python python = Python.getInstance(); PyObject pyObject = python.getModule("script"); PyObject result = pyObject.callAttr("main",outputTemp); List<PyObject> mfccList = result.asList(); double [] flattenData = flattenData(mfccList); StringBuilder stringBuilder = new StringBuilder(); for (int i = 0; i < flattenData.length; i++) { stringBuilder.append(flattenData[i]+" "); } Log.i("bacnv", stringBuilder.toString()); try { OnnxTensor onnxTensor = OnnxTensor.createTensor(ortEnvironment, DoubleBuffer.wrap(flattenData),new long[]{1,44,13,1}); Map<String,OnnxTensor> inputModel = new HashMap<>(); inputModel.put("input",onnxTensor); OrtSession.Result output = ortSession.run(inputModel); OnnxTensor outputOnnxTensor = (OnnxTensor) output.get(0); float[] outputArray = outputOnnxTensor.getFloatBuffer().array(); Log.i("bacnv", "inference: "+outputArray.length); int maxScoreIndex = argmax(outputArray); Log.i("bacnv", "Result predict - index: "+maxScoreIndex); Log.i("bacnv", "Result predict: "+labels[maxScoreIndex]); // new File(outputTemp).delete(); return maxScoreIndex; } catch (OrtException e) { throw new RuntimeException(e); } } private int argmax(float[] array) { int maxIdx = 0; double maxVal = -MAX_VALUE; for (int j = 0; j < array.length; j++) { if (array[j] > maxVal) { maxVal = array[j]; maxIdx = j; } } return maxIdx; } private double[] flattenData(List<PyObject> mfccList) { int index = 0; double[] res = new double[inputShapes[0]*inputShapes[1]]; for (int i = 0; i < inputShapes[0] ; i++) { List<PyObject> mfccListSub = mfccList.get(i).asList(); for (int j = 0; j < inputShapes[1]; j++) { res[index] = mfccListSub.get(j).toDouble(); index += 1; } } return res; } } -database -metadata --RecorderMetadata.java package com.sec.android.app.genremusiconnx.metadata; import android.content.Context; import 
com.sec.android.app.genremusiconnx.database.AppDatabase; import com.sec.android.app.genremusiconnx.database.RecordItem; import com.sec.android.app.genremusiconnx.database.RecordItemDAO; public class RecorderMetadata { private static RecorderMetadata mInstance; private String name; private long startTime; private long endTime; private String path; private int category; public RecorderMetadata() { } public int getCategory() { return category; } public void setCategory(int category) { this.category = category; } public String getPath() { return path; } public void setPath(String path) { this.path = path; } public void setStartTime(long startTime) { this.startTime = startTime; } public void setEndTime(long endTime) { this.endTime = endTime; } public long getStartTime() { return startTime; } public long getEndTime() { return endTime; } public String getName() { return name; } public long getDuration() { return endTime - startTime; } public void setName(String name) { this.name = name; } public static RecorderMetadata getInstance() { if (mInstance == null) { mInstance = new RecorderMetadata(); } return mInstance; } public void release() { mInstance = null; } public void saveToDatabase(Context context) { RecordItemDAO recordItemDAO = AppDatabase.getInstance(context).recordItemDAO(); recordItemDAO.insert(new RecordItem( RecorderMetadata.getInstance().getName(), RecorderMetadata.getInstance().path, RecorderMetadata.getInstance().getCategory(), RecorderMetadata.getInstance().getEndTime(), RecorderMetadata.getInstance().getDuration())); RecorderMetadata.getInstance().release(); } } -ulti -viewpackage com.sec.android.app.genremusiconnx.view; import static java.lang.Double.MAX_VALUE; --RecordFragment.java package com.sec.android.app.genremusiconnx.view; import android.media.MediaRecorder; import android.os.Bundle; import android.os.Handler; import android.os.Looper; import android.util.Log; import android.view.LayoutInflater; import android.view.View; import 
android.view.ViewGroup; import android.widget.Button; import android.widget.FrameLayout; import android.widget.ImageView; import android.widget.TextView; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.fragment.app.Fragment; import com.sec.android.app.genremusiconnx.R; import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier; import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata; import com.sec.android.app.genremusiconnx.util.BackgroundTask; import java.text.SimpleDateFormat; public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener { private TextView timeRecord; private Button recordControl; private Boolean isRecording; private ImageView buttonOpenList; private FrameLayout progressLayout; public RecordFragment() { } @Nullable @Override public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) { View view = inflater.inflate(R.layout.record_fragment, container, false); initView(view); // InstrumentClassifier instrumentClassifier = new InstrumentClassifier(getContext()); // instrumentClassifier.inference("/storage/emulated/0/Ringtones/piano_4.wav"); return view; } private void initView(View view) { timeRecord = view.findViewById(R.id.time_record); recordControl = view.findViewById(R.id.control_record); buttonOpenList = view.findViewById(R.id.open_list); isRecording = false; progressLayout = view.findViewById(R.id.progress_layout); recordControl.setOnClickListener(v -> handleControlRecord()); buttonOpenList.setOnClickListener(v ->{ if (!isRecording) { FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT); } }); } private void handleControlRecord() { if (isRecording) { stopRecord(); isRecording = false; } else { startRecord(); isRecording = true; } } private void startRecord() { recordControl.setText("Stop record"); Recorder.getInstance().startRecord(getContext(), this); 
RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis()); Thread timer = new Thread(() -> { int tick = 0; while (isRecording) { int finalTick = tick; new Handler(Looper.getMainLooper()).post(() -> updateUi(finalTick)); tick += 1000; try { Thread.sleep(1000); } catch (InterruptedException e) { throw new RuntimeException(e); } } }); timer.start(); } private void updateUi(int tick) { SimpleDateFormat simpleDateFormat = new SimpleDateFormat("mm:ss"); timeRecord.setText(simpleDateFormat.format(tick)); } private void stopRecord() { isRecording = false; recordControl.setText("Start record"); Recorder.getInstance().stopRecord(); RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis()); InstrumentClassifier modelClassifier = new InstrumentClassifier(getContext()); final int[] classifiedCategoryId = {-1}; BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() { @Override public void doInBackground() { getActivity().runOnUiThread(new Runnable() { @Override public void run() { progressLayout.setVisibility(View.VISIBLE); } }); classifiedCategoryId[0] = modelClassifier.inference(RecorderMetadata.getInstance().getPath()); } @Override public void onDone() { getActivity().runOnUiThread(() -> { progressLayout.setVisibility(View.GONE); RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]); RecorderMetadata.getInstance().saveToDatabase(getContext()); new Handler(Looper.getMainLooper()).postDelayed(() -> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500); }); } }); backgroundTask.start(); } @Override public void onInfo(MediaRecorder mr, int what, int extra) { Log.i("bacnv", "onInfo: " + extra); if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) { stopRecord(); } } } import android.Manifest; import android.content.pm.PackageManager; import android.os.Bundle; import android.util.Log; import android.widget.Button; import android.widget.TextView; import 
android.widget.Toast; import android.window.OnBackInvokedDispatcher; import androidx.annotation.NonNull; import androidx.appcompat.app.AppCompatActivity; import androidx.core.app.ActivityCompat; import androidx.core.content.ContextCompat; import androidx.fragment.app.FragmentManager; import com.chaquo.python.PyObject; import com.chaquo.python.Python; import com.chaquo.python.android.AndroidPlatform; import com.sec.android.app.genremusiconnx.R; import com.sec.android.app.genremusiconnx.database.AppDatabase; import com.sec.android.app.genremusiconnx.database.Category; import com.sec.android.app.genremusiconnx.database.CategoryDAO; import com.sec.android.app.genremusiconnx.database.RecordItem; import com.sec.android.app.genremusiconnx.database.RecordItemDAO; import java.io.IOException; import java.io.InputStream; import java.nio.FloatBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; public class MainActivity extends AppCompatActivity { public static final int REQUEST_CODE = 1; private FragmentController fragmentController; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); fragmentController = FragmentController.getInstance(R.id.layout_container); fragmentController.setActivity(this); initDatabase(); if (!Python.isStarted()){ Python.start(new AndroidPlatform(this)); } checkPermissions(new String[] {Manifest.permission.READ_MEDIA_AUDIO, Manifest.permission.RECORD_AUDIO, Manifest.permission.WRITE_EXTERNAL_STORAGE}, REQUEST_CODE); } private void initDatabase() { CategoryDAO categoryDAO = AppDatabase.getInstance(this).categoryDAO(); if (categoryDAO.getAllCategory().isEmpty()) { categoryDAO.insert( new Category(0,"Dan bau"), new Category(1,"Dan tranh"), new Category(2,"Guitar"), new Category(3,"Violin") ); } } private 
void checkPermissions(String[] permissions, int requestCode) { for(int i = 0; i < permissions.length; i++) { if (ContextCompat.checkSelfPermission(MainActivity.this,permissions[i]) != PackageManager.PERMISSION_GRANTED){ ActivityCompat.requestPermissions(MainActivity.this, permissions, requestCode); return; } else { // Toast.makeText(this,"Permission granted!",Toast.LENGTH_SHORT).show(); } } showRecordFragment(); } @Override public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { super.onRequestPermissionsResult(requestCode, permissions, grantResults); if(requestCode == REQUEST_CODE){ if(grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED){ //do something showRecordFragment(); } else { Toast.makeText(this,"Permission denied!",Toast.LENGTH_SHORT).show(); } } } private void showRecordFragment() { fragmentController.navigateTo(FragmentController.RECORD_FRAGMENT); } @Override protected void onResume() { super.onResume(); } @Override protected void onDestroy() { super.onDestroy(); } @Override public void onBackPressed() { if(!fragmentController.navigateBack()) { super.onBackPressed(); } } } --MainActivoty.java ... 
--Recorder.java
package com.sec.android.app.genremusiconnx.view;

import android.content.Context;
import android.media.MediaFormat;
import android.media.MediaRecorder;
import android.os.Environment;
import android.util.Log;

import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;

import java.io.File;
import java.io.IOException;

/**
 * Singleton wrapper around one {@link MediaRecorder} that records microphone audio into
 * {@code Recordings/GenreMusic} on external storage and publishes the file name and path through
 * {@link RecorderMetadata}.
 */
public class Recorder {
    private static final String TAG = "bacnv";

    private static Recorder mInstance;
    private static MediaRecorder mMediaRecorder;

    private Recorder() {
    }

    /** Returns the shared instance, creating it on first use. */
    public static synchronized Recorder getInstance() {
        if (mInstance == null) {
            mInstance = new Recorder();
        }
        return mInstance;
    }

    /**
     * Prepares a new recorder and starts capturing.
     *
     * @param context  context handed to the {@link MediaRecorder} constructor
     * @param listener receives recorder info events such as the max-duration notification
     * @throws RuntimeException if the recorder cannot be prepared
     */
    public void startRecord(Context context, MediaRecorder.OnInfoListener listener) {
        mMediaRecorder = getMediaRecorder(context);
        mMediaRecorder.setOnInfoListener(listener);
        mMediaRecorder.start();
    }

    /** Stops and releases the current recorder; does nothing when no recorder is active. */
    public void stopRecord() {
        if (mMediaRecorder == null) {
            return;
        }
        try {
            mMediaRecorder.stop();
        } finally {
            // Release even when stop() throws, and forget the instance so a second call
            // cannot touch an already released recorder.
            mMediaRecorder.release();
            mMediaRecorder = null;
        }
    }

    /**
     * Creates and prepares a recorder: mono, 22050 Hz, microphone source, 3GPP container with AAC
     * audio, limited to 10 seconds.
     */
    private MediaRecorder getMediaRecorder(Context context) {
        MediaRecorder recorder = new MediaRecorder(context);
        String outputFile = getFileName(context);
        RecorderMetadata.getInstance().setPath(outputFile);
        try {
            recorder.setAudioChannels(1);
            recorder.setAudioSamplingRate(22050);
            recorder.setAudioSource(MediaRecorder.AudioSource.MIC);
            recorder.setOutputFormat(MediaRecorder.OutputFormat.THREE_GPP);
            recorder.setAudioEncoder(MediaRecorder.AudioEncoder.AAC);
            recorder.setMaxDuration(10000);
            recorder.setOutputFile(outputFile);
            recorder.prepare();
        } catch (IOException e) {
            Log.i(TAG, "getMediaRecorder: "+e.getMessage());
            // Free the native recorder instead of leaving a half-configured one behind.
            recorder.release();
            throw new RuntimeException(e);
        }
        mMediaRecorder = recorder;
        return recorder;
    }

    /**
     * Ensures the output folder exists, stores a fresh {@code recording_<millis>} name in the
     * metadata and returns the full output path.
     *
     * <p>NOTE(review): the path ends in ".wav" although the recorder is configured for a 3GPP/AAC
     * stream. The suffix is kept because other code derives file names from it; confirm before
     * changing.
     */
    private String getFileName(Context context) {
        File file = new File(Environment.getExternalStorageDirectory(),"Recordings/GenreMusic");
        if (file.exists()) {
            Log.i(TAG, "startRecord: Folder exist!!");
        } else {
            Log.i(TAG, "startRecord: "+file.getPath());
            if (!file.mkdirs()) {
                Log.i(TAG, "can not create folder: ");
            }
        }
        String fileName = "recording_"+System.currentTimeMillis();
        RecorderMetadata.getInstance().setName(fileName);
        return file.getPath()+"/"+fileName+".wav";
    }
}

--RecordFragment.java
package com.sec.android.app.genremusiconnx.view;

import android.content.Context;
import android.media.MediaRecorder;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Button;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.fragment.app.Fragment;

import com.sec.android.app.genremusiconnx.R;
import com.sec.android.app.genremusiconnx.ai.InstrumentClassifier;
import com.sec.android.app.genremusiconnx.metadata.RecorderMetadata;
import com.sec.android.app.genremusiconnx.util.BackgroundTask;

import java.text.SimpleDateFormat;
import java.util.Locale;

/**
 * Screen with a start/stop button and an elapsed-time label. Stopping a recording classifies it,
 * saves its metadata and then navigates to the list screen.
 */
public class RecordFragment extends Fragment implements MediaRecorder.OnInfoListener {
    private static final String TAG = "bacnv";
    private static final int TICK_MILLIS = 1000;

    private final Handler mainHandler = new Handler(Looper.getMainLooper());

    private TextView timeRecord;
    private Button recordControl;
    // Written on the main thread and polled by the timer thread, hence volatile.
    private volatile boolean isRecording;
    private ImageView buttonOpenList;
    private FrameLayout progressLayout;

    public RecordFragment() {
    }

    @Nullable
    @Override
    public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
        View view = inflater.inflate(R.layout.record_fragment, container, false);
        initView(view);
        return view;
    }

    /** Looks up the views and wires the click listeners. */
    private void initView(View view) {
        timeRecord = view.findViewById(R.id.time_record);
        recordControl = view.findViewById(R.id.control_record);
        buttonOpenList = view.findViewById(R.id.open_list);
        isRecording = false;
        progressLayout = view.findViewById(R.id.progress_layout);
        recordControl.setOnClickListener(v -> handleControlRecord());
        buttonOpenList.setOnClickListener(v ->{
            if (!isRecording) {
                FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT);
            }
        });
    }

    /** Toggles between recording and idle; the flag itself is maintained by start/stop. */
    private void handleControlRecord() {
        if (isRecording) {
            stopRecord();
        } else {
            startRecord();
        }
    }

    /** Starts the recorder and a thread that refreshes the elapsed-time label once per second. */
    private void startRecord() {
        recordControl.setText("Stop record");
        Recorder.getInstance().startRecord(getContext(), this);
        RecorderMetadata.getInstance().setStartTime(System.currentTimeMillis());
        // Raise the flag BEFORE the timer thread starts; otherwise its first loop check can
        // observe false and the thread exits without ever updating the label.
        isRecording = true;
        Thread timer = new Thread(() -> {
            int tick = 0;
            while (isRecording) {
                int finalTick = tick;
                mainHandler.post(() -> updateUi(finalTick));
                tick += TICK_MILLIS;
                try {
                    Thread.sleep(TICK_MILLIS);
                } catch (InterruptedException e) {
                    // Preserve the interrupt and end the timer quietly instead of crashing the app.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        });
        timer.start();
    }

    /**
     * Shows the elapsed time as mm:ss.
     *
     * @param tick elapsed milliseconds
     */
    private void updateUi(int tick) {
        // Plain arithmetic: formatting a duration as a calendar date shifts the minutes in
        // time zones whose offset is not a whole hour.
        long totalSeconds = tick / 1000;
        timeRecord.setText(String.format(Locale.US, "%02d:%02d",
                (totalSeconds / 60) % 60, totalSeconds % 60));
    }

    /** Stops the recorder, classifies the file in the background, saves it and opens the list. */
    private void stopRecord() {
        isRecording = false;
        recordControl.setText("Start record");
        Recorder.getInstance().stopRecord();
        RecorderMetadata.getInstance().setEndTime(System.currentTimeMillis());

        // Capture everything the background work needs while the fragment is still attached.
        final Context appContext = requireContext().getApplicationContext();
        final String recordedPath = RecorderMetadata.getInstance().getPath();
        final int[] classifiedCategoryId = {-1};
        progressLayout.setVisibility(View.VISIBLE);

        BackgroundTask backgroundTask = new BackgroundTask(new BackgroundTask.BackgroundCallback() {
            @Override
            public void doInBackground() {
                // Build the classifier here so loading the model does not block the UI,
                // and always release its native session afterwards.
                InstrumentClassifier modelClassifier = new InstrumentClassifier(appContext);
                try {
                    classifiedCategoryId[0] = modelClassifier.inference(recordedPath);
                } finally {
                    modelClassifier.release();
                }
            }

            @Override
            public void onDone() {
                mainHandler.post(() -> {
                    RecorderMetadata.getInstance().setCategory(classifiedCategoryId[0]);
                    RecorderMetadata.getInstance().saveToDatabase(appContext);
                    if (!isAdded()) {
                        // The fragment went away while classifying; the result is saved, skip the UI.
                        return;
                    }
                    progressLayout.setVisibility(View.GONE);
                    mainHandler.postDelayed(() -> FragmentController.getCurrentInstance().navigateTo(FragmentController.LIST_FRAGMENT),500);
                });
            }
        });
        backgroundTask.start();
    }

    @Override
    public void onInfo(MediaRecorder mr, int what, int extra) {
        Log.i(TAG, "onInfo: " + extra);
        if (what == MediaRecorder.MEDIA_RECORDER_INFO_MAX_DURATION_REACHED) {
            stopRecord();
        }
    }
}
Editor is loading...
Leave a Comment